diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f7a8a2e4de24..a329e14f25aeb 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -12,7 +12,8 @@ export("print.jobj") # MLlib integration exportMethods("glm", - "predict") + "predict", + "summary") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 2fb6fae55f28c..49162838b8d1a 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) { # TODO: check the status code to output error information returnStatus <- readInt(conn) - stopifnot(returnStatus == 0) + if (returnStatus != 0) { + stop(readString(conn)) + } readObject(conn) } diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 6f772158ddfe8..c811d1dac3bd5 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -48,7 +48,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack jars <- paste("--jars", jars) } - if (packages != "") { + if (!identical(packages, "")) { packages <- paste("--packages", packages) } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 836e0175c391f..a3a121058e165 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -254,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") # @rdname intersection # @export -setGeneric("intersection", function(x, other, numPartitions = 1) { - standardGeneric("intersection") }) +setGeneric("intersection", + function(x, other, numPartitions = 1) { + standardGeneric("intersection") + }) # @rdname keys # @export @@ -489,9 +491,7 @@ setGeneric("sample", #' @rdname sample #' @export setGeneric("sample_frac", - function(x, withReplacement, fraction, seed) { - standardGeneric("sample_frac") - }) + function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname saveAsParquetFile #' @export @@ -553,8 +553,8 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn #' @rdname withColumnRenamed #' @export -setGeneric("withColumnRenamed", function(x, existingCol, newCol) { - standardGeneric("withColumnRenamed") }) +setGeneric("withColumnRenamed", + function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) ###################### Column Methods ########################## diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 258e354081fc1..efddcc1d8d71c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj")) #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~' and '+'. +#' operators are supported, including '~', '+', '-', and '.'. #' @param data DataFrame for training #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter @@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"), function(object, newData) { return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) }) + +#' Get the summary of a model +#' +#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' +#' @param model A fitted MLlib model +#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See +#' summary.glm for more information. 
+#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(object = "PipelineModel"), + function(object) { + features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelFeatures", object@model) + weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelWeights", object@model) + coefficients <- as.matrix(unlist(weights)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index ebc6ff65e9d0f..83801d3209700 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -202,8 +202,8 @@ setMethod("partitionBy", packageNamesArr <- serialize(.sparkREnv$.packages, connection = NULL) - broadcastArr <- lapply(ls(.broadcastNames), function(name) { - get(name, .broadcastNames) }) + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) jrdd <- getJRDD(x) # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 76c15875b50d5..e83104f116422 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -22,7 +22,8 @@ connExists <- function(env) { tryCatch({ exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) - }, error = function(err) { + }, + error = function(err) { return(FALSE) }) } @@ -153,7 +154,8 @@ sparkR.init <- function( .sparkREnv$backendPort <- backendPort tryCatch({ connectBackend("localhost", backendPort) - }, error = function(err) { + }, + error = function(err) { stop("Failed to connect JVM\n") }) @@ -264,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) { ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { stop("Spark SQL is not built with Hive support") }) diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R index 30b05c1a2afcd..8a20991f89af8 100644 --- a/R/pkg/inst/tests/test_client.R +++ b/R/pkg/inst/tests/test_client.R @@ -30,3 +30,7 @@ test_that("no package specified doesn't add packages flag", { expect_equal(gsub("[[:space:]]", "", args), "") }) + +test_that("multiple packages don't produce a warning", { + expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) +}) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 29152a11688a2..f272de78ad4a6 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -40,3 +40,22 @@ test_that("predictions match with native glm", { rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) + +test_that("dot minus and intercept vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . 
- Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(sqlContext, iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) +}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 62fe48a5d6c7b..61c8a7ec7d837 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -112,7 +112,8 @@ test_that("create DataFrame from RDD", { df <- jsonFile(sqlContext, jsonPathNa) hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") @@ -602,7 +603,8 @@ test_that("write.df() as parquet file", { test_that("test HiveContext", { hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") @@ -1000,6 +1002,11 @@ test_that("crosstab() on a DataFrame", { expect_identical(expected, ordered) }) +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table Not Found: blah", retError), TRUE) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/R/run-tests.sh b/R/run-tests.sh index e82ad0ba2cd06..18a1e13bdc655 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then diff --git a/core/pom.xml b/core/pom.xml index 95f36eb348698..202678779150b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -34,6 +34,11 @@ Spark Project Core http://spark.apache.org/ + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + com.google.guava guava @@ -281,7 +286,7 @@ org.tachyonproject tachyon-client - 0.6.4 + 0.7.0 org.apache.hadoop @@ -292,36 +297,12 @@ curator-recipes - org.eclipse.jetty - jetty-jsp - - - org.eclipse.jetty - jetty-webapp - - - org.eclipse.jetty - jetty-server - - - org.eclipse.jetty - jetty-servlet + org.tachyonproject + tachyon-underfs-glusterfs - junit - junit - - - org.powermock - powermock-module-junit4 - - - org.powermock - powermock-api-mockito - - - org.apache.curator - curator-test + org.tachyonproject + tachyon-underfs-s3 diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 1d460432be9ff..1aa6ba4201261 100644 --- 
a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -59,14 +59,14 @@ final class UnsafeShuffleExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); - private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; - @VisibleForTesting - static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; private final int initialSize; private final int numPartitions; + private final int pageSizeBytes; + @VisibleForTesting + final int maxRecordSizeBytes; private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; @@ -109,7 +109,10 @@ public UnsafeShuffleExternalSorter( this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - + this.pageSizeBytes = (int) Math.min( + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, + conf.getSizeAsBytes("spark.buffer.pageSize", "64m")); + this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); } @@ -272,7 +275,11 @@ void spill() throws IOException { } private long getMemoryUsage() { - return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return sorter.getMemoryUsage() + totalPageSize; } private long freeMemory() { @@ -346,23 +353,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { // TODO: we should track metrics on the amount of space wasted when we roll over to a new page // without using the free space at the end of the current page. We should also do this for // BytesToBytesMap. 
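The hunks in this file replace the compile-time PAGE_SIZE constant with a size read from spark.buffer.pageSize (which R/run-tests.sh above now pins to 4m for the test JVM). A brief sketch, in Scala rather than the Java of this file, of how the new knob is read; SparkConf.getSizeAsBytes and the 64m default come from this diff, the rest is illustrative:

```scala
import org.apache.spark.SparkConf

// Illustrative configuration: shrink the unsafe sorters' page allocations, as the
// SparkR test runner above does for its small test JVM.
val conf = new SparkConf().set("spark.buffer.pageSize", "4m")

// getSizeAsBytes parses size strings such as "4m"; the sorters additionally cap the
// result at the largest page their packed record pointers can address.
val pageSizeBytes: Long = conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
```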
- if (requiredSpace > PAGE_SIZE) { + if (requiredSpace > pageSizeBytes) { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { shuffleMemoryManager.release(memoryAcquired); spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPage = memoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = PAGE_SIZE; + freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 764578b181422..d47d6fc9c2ac4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -129,6 +129,11 @@ public UnsafeShuffleWriter( open(); } + @VisibleForTesting + public int maxRecordSizeBytes() { + return sorter.maxRecordSizeBytes; + } + /** * This convenience method should only be called in test code. 
*/ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index bf1bc5dffba78..4d7e5b3dfba6e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -17,9 +17,7 @@ package org.apache.spark.util.collection.unsafe.sort; -import com.google.common.base.Charsets; -import com.google.common.primitives.Longs; -import com.google.common.primitives.UnsignedBytes; +import com.google.common.primitives.UnsignedLongs; import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.types.UTF8String; @@ -30,81 +28,67 @@ public class PrefixComparators { private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); - public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); - public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); public static final class StringPrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - // TODO: can done more efficiently - byte[] a = Longs.toByteArray(aPrefix); - byte[] b = Longs.toByteArray(bPrefix); - for (int i = 0; i < 8; i++) { - int c = UnsignedBytes.compare(a[i], b[i]); - if (c != 0) return c; - } - return 0; + return UnsignedLongs.compare(aPrefix, bPrefix); } - public long computePrefix(byte[] bytes) { - if (bytes == null) { - return 0L; - } else { - byte[] padded = new byte[8]; - System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8)); - return Longs.fromByteArray(padded); - } - } - - public long computePrefix(String value) { - return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8)); + public static long computePrefix(UTF8String value) { + return value == null ? 0L : value.getPrefix(); } + } - public long computePrefix(UTF8String value) { - return value == null ? 0L : computePrefix(value.getBytes()); + public static final class StringPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); } } - /** - * Prefix comparator for all integral types (boolean, byte, short, int, long). - */ - public static final class IntegralPrefixComparator extends PrefixComparator { + public static final class LongPrefixComparator extends PrefixComparator { @Override public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } + } - public final long NULL_PREFIX = Long.MIN_VALUE; + public static final class LongPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; + } } - public static final class FloatPrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - float a = Float.intBitsToFloat((int) aPrefix); - float b = Float.intBitsToFloat((int) bPrefix); - return Utils.nanSafeCompareFloats(a, b); + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(float value) { - return Float.floatToIntBits(value) & 0xffffffffL; + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); } - public static final class DoublePrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparatorDesc extends PrefixComparator { @Override - public int compare(long aPrefix, long bPrefix) { + public int compare(long bPrefix, long aPrefix) { double a = Double.longBitsToDouble(aPrefix); double b = Double.longBitsToDouble(bPrefix); return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(double value) { + public static long computePrefix(double value) { return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 80b03d7e99e2b..866e0b4151577 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -20,6 +20,9 @@ import java.io.IOException; import java.util.LinkedList; +import scala.runtime.AbstractFunction0; +import scala.runtime.BoxedUnit; + import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,10 +44,7 @@ public final class UnsafeExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); - private static final int PAGE_SIZE = 1 << 27; // 128 megabytes - @VisibleForTesting - static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; - + private final long pageSizeBytes; private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; private final int initialSize; @@ -91,7 +91,19 @@ public UnsafeExternalSorter( this.initialSize = initialSize; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); initializeForWriting(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). 
+ taskContext.addOnCompleteCallback(new AbstractFunction0() { + @Override + public BoxedUnit apply() { + freeMemory(); + return null; + } + }); } // TODO: metrics tracking + integration with shuffle write metrics @@ -147,7 +159,11 @@ public void spill() throws IOException { } private long getMemoryUsage() { - return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return sorter.getMemoryUsage() + totalPageSize; } @VisibleForTesting @@ -214,23 +230,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { // TODO: we should track metrics on the amount of space wasted when we roll over to a new page // without using the free space at the end of the current page. We should also do this for // BytesToBytesMap. - if (requiredSpace > PAGE_SIZE) { + if (requiredSpace > pageSizeBytes) { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { shuffleMemoryManager.release(memoryAcquired); spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPage = memoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = PAGE_SIZE; + freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); } } diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 2f4fcac890eef..eb75f26718e19 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -341,7 +341,4 @@ private[spark] object Accumulators extends Logging { } } - def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue) - - def stringifyValue(value: Any): String = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 6cf36fbbd6254..4161792976c7b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,11 +18,12 @@ package org.apache.spark import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.collection.mutable.LinkedHashSet +import org.apache.avro.{SchemaNormalization, Schema} + import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils @@ -161,6 +162,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { this } + private final val avroNamespace = "avro.schema." 
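The registerAvroSchemas / getAvroSchema helpers added to SparkConf just below pair with the GenericAvroSerializer introduced later in this diff: schemas registered up front are shipped as 64-bit fingerprints instead of full schema text. A hedged usage sketch, with a record schema made up purely for illustration:

```scala
import org.apache.avro.Schema
import org.apache.spark.SparkConf

// Any schema whose GenericRecords will cross the wire can be registered ahead of time.
val userSchema: Schema = new Schema.Parser().parse(
  """{"type":"record","name":"User","fields":[{"name":"name","type":"string"}]}""")

// With the schema registered, Kryo sends only its fingerprint with each record;
// unregistered schemas fall back to sending a compressed copy of the schema itself.
val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .registerAvroSchemas(userSchema)
```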
+ + /** + * Use Kryo serialization and register the given set of Avro schemas so that the generic + * record serializer can decrease network IO + */ + def registerAvroSchemas(schemas: Schema*): SparkConf = { + for (schema <- schemas) { + set(avroNamespace + SchemaNormalization.parsingFingerprint64(schema), schema.toString) + } + this + } + + /** Gets all the avro schemas in the configuration used in the generic Avro record serializer */ + def getAvroSchema: Map[Long, String] = { + getAll.filter { case (k, v) => k.startsWith(avroNamespace) } + .map { case (k, v) => (k.substring(avroNamespace.length).toLong, v) } + .toMap + } + /** Remove a parameter from the configuration */ def remove(key: String): SparkConf = { settings.remove(key) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 598953ac3bcc8..55e563ee968be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -207,6 +207,7 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { + TaskContext.setTaskContext(context) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index @@ -263,11 +264,6 @@ private[spark] class PythonRDD( if (!worker.isClosed) { Utils.tryLog(worker.shutdownOutput()) } - } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() } } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index a5de10fe89c42..14dac4ed28ce3 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -69,8 +69,11 @@ private[r] class RBackendHandler(server: RBackend) case e: Exception => logError(s"Removing $objId failed", e) writeInt(dos, -1) + writeString(dos, s"Removing $objId failed: ${e.getMessage}") } - case _ => dos.writeInt(-1) + case _ => + dos.writeInt(-1) + writeString(dos, s"Error: unknown method $methodName") } } else { handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) @@ -146,8 +149,11 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed", e) + logError(s"$methodName on $objId failed") writeInt(dos, -1) + // Writing the error message of the cause for the exception. This will be returned + // to user in the R process. 
+ writeString(dos, Utils.exceptionString(e.getCause)) } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 23a470d6afcae..1cf2824f862ee 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -112,6 +112,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( partition: Int): Unit = { val env = SparkEnv.get + val taskContext = TaskContext.get() val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val stream = new BufferedOutputStream(output, bufferSize) @@ -119,6 +120,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( override def run(): Unit = { try { SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) val dataOut = new DataOutputStream(stream) dataOut.writeInt(partition) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e76664f1bd7b0..7bc7fce7ae8dd 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -313,10 +313,6 @@ private[spark] class Executor( } } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() runningTasks.remove(taskId) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index defdabf95ac4b..3bb9998e1db44 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -133,6 +133,7 @@ private[spark] class PipedRDD[T: ClassTag]( // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { override def run() { + TaskContext.setTaskContext(context) val out = new PrintWriter(proc.getOutputStream) // scalastyle:off println diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala similarity index 91% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala rename to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 3d75b6a91def6..35e44cb59c1be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -15,12 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution +package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import org.apache.spark.{Partition => SparkPartition, _} +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,12 +31,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -62,7 +63,7 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be * folded into core. */ -private[sql] class SqlNewHadoopRDD[K, V]( +private[spark] class SqlNewHadoopRDD[K, V]( @transient sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], @@ -128,6 +129,12 @@ private[sql] class SqlNewHadoopRDD[K, V]( val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Sets the thread local variable for the file's name + split.serializableHadoopSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { @@ -188,6 +195,8 @@ private[sql] class SqlNewHadoopRDD[K, V]( reader.close() reader = null + SqlNewHadoopRDD.unsetInputFileName() + if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || @@ -250,6 +259,21 @@ private[sql] class SqlNewHadoopRDD[K, V]( } private[spark] object SqlNewHadoopRDD { + + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. 
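The writer threads patched above (PythonRDD, RRDD, PipedRDD) now install the parent task's TaskContext before doing any work, so the per-task shuffle and unroll memory accounting introduced later in this diff charges their allocations to the right task attempt. A minimal sketch of that pattern, assuming it lives inside Spark's own org.apache.spark code (TaskContext.setTaskContext is not public API):

```scala
import org.apache.spark.TaskContext

// Capture the context on the task's main thread...
val context: TaskContext = TaskContext.get()

new Thread("stdin writer (sketch)") {
  override def run(): Unit = {
    // ...and re-install it on the helper thread, mirroring the PythonRDD/RRDD/PipedRDD
    // changes, so TaskContext.get() and anything keyed by taskAttemptId work here too.
    TaskContext.setTaskContext(context)
    // feed data to the external worker process here
  }
}.start()
```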
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index cdf6078421123..c4fa277c21254 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -916,11 +916,9 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) - val stringValue = Accumulators.stringifyValue(acc.value) - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) + stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}") event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}") } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d11a00956a9a9..1978305cfefbd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{TaskContextImpl, TaskContext} +import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -86,7 +86,18 @@ private[spark] abstract class Task[T]( (runTask(context), context.collectAccumulators()) } finally { context.markTaskCompleted() - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for shuffles + SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() + } + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + } + } finally { + TaskContext.unset() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 660702f6e6fd0..bd89160af4ffa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -241,7 +241,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.executorLost(executorId, SlaveLost(reason)) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) - case None => logError(s"Asked to remove non-existent executor $executorId") + case None => logInfo(s"Asked to remove non-existent executor $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 074282d1be37d..044f6288fabdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -109,6 +109,8 @@ private[spark] abstract class YarnSchedulerBackend( case 
AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) + case RemoveExecutor(executorId, reason) => + removeExecutor(executorId, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala new file mode 100644 index 0000000000000..62f8aae7f2126 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import scala.collection.mutable + +import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import org.apache.avro.{Schema, SchemaNormalization} +import org.apache.avro.generic.{GenericData, GenericRecord} +import org.apache.avro.io._ +import org.apache.commons.io.IOUtils + +import org.apache.spark.{SparkException, SparkEnv} +import org.apache.spark.io.CompressionCodec + +/** + * Custom serializer used for generic Avro records. If the user registers the schemas + * ahead of time, then the schema's fingerprint will be sent with each message instead of the actual + * schema, as to reduce network IO. + * Actions like parsing or compressing schemas are computationally expensive so the serializer + * caches all previously seen values as to reduce the amount of work needed to do. + * @param schemas a map where the keys are unique IDs for Avro schemas and the values are the + * string representation of the Avro schema, used to decrease the amount of data + * that needs to be serialized. 
+ */ +private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) + extends KSerializer[GenericRecord] { + + /** Used to reduce the amount of effort to compress the schema */ + private val compressCache = new mutable.HashMap[Schema, Array[Byte]]() + private val decompressCache = new mutable.HashMap[ByteBuffer, Schema]() + + /** Reuses the same datum reader/writer since the same schema will be used many times */ + private val writerCache = new mutable.HashMap[Schema, DatumWriter[_]]() + private val readerCache = new mutable.HashMap[Schema, DatumReader[_]]() + + /** Fingerprinting is very expensive so this alleviates most of the work */ + private val fingerprintCache = new mutable.HashMap[Schema, Long]() + private val schemaCache = new mutable.HashMap[Long, Schema]() + + // GenericAvroSerializer can't take a SparkConf in the constructor b/c then it would become + // a member of KryoSerializer, which would make KryoSerializer not Serializable. We make + // the codec lazy here just b/c in some unit tests, we use a KryoSerializer w/out having + // the SparkEnv set (note those tests would fail if they tried to serialize avro data). + private lazy val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + + /** + * Used to compress Schemas when they are being sent over the wire. + * The compression results are memoized to reduce the compression time since the + * same schema is compressed many times over + */ + def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, { + val bos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(bos) + out.write(schema.toString.getBytes("UTF-8")) + out.close() + bos.toByteArray + }) + + /** + * Decompresses the schema into the actual in-memory object. Keeps an internal cache of already + * seen values so to limit the number of times that decompression has to be done. + */ + def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { + val bis = new ByteArrayInputStream(schemaBytes.array()) + val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) + new Schema.Parser().parse(new String(bytes, "UTF-8")) + }) + + /** + * Serializes a record to the given output stream. It caches a lot of the internal data as + * to not redo work + */ + def serializeDatum[R <: GenericRecord](datum: R, output: KryoOutput): Unit = { + val encoder = EncoderFactory.get.binaryEncoder(output, null) + val schema = datum.getSchema + val fingerprint = fingerprintCache.getOrElseUpdate(schema, { + SchemaNormalization.parsingFingerprint64(schema) + }) + schemas.get(fingerprint) match { + case Some(_) => + output.writeBoolean(true) + output.writeLong(fingerprint) + case None => + output.writeBoolean(false) + val compressedSchema = compress(schema) + output.writeInt(compressedSchema.length) + output.writeBytes(compressedSchema) + } + + writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) + .asInstanceOf[DatumWriter[R]] + .write(datum, encoder) + encoder.flush() + } + + /** + * Deserializes generic records into their in-memory form. There is internal + * state to keep a cache of already seen schemas and datum readers. 
+ */ + def deserializeDatum(input: KryoInput): GenericRecord = { + val schema = { + if (input.readBoolean()) { + val fingerprint = input.readLong() + schemaCache.getOrElseUpdate(fingerprint, { + schemas.get(fingerprint) match { + case Some(s) => new Schema.Parser().parse(s) + case None => + throw new SparkException( + "Error reading attempting to read avro data -- encountered an unknown " + + s"fingerprint: $fingerprint, not sure what schema to use. This could happen " + + "if you registered additional schemas after starting your spark context.") + } + }) + } else { + val length = input.readInt() + decompress(ByteBuffer.wrap(input.readBytes(length))) + } + } + val decoder = DecoderFactory.get.directBinaryDecoder(input, null) + readerCache.getOrElseUpdate(schema, GenericData.get.createDatumReader(schema)) + .asInstanceOf[DatumReader[GenericRecord]] + .read(null, decoder) + } + + override def write(kryo: Kryo, output: KryoOutput, datum: GenericRecord): Unit = + serializeDatum(datum, output) + + override def read(kryo: Kryo, input: KryoInput, datumClass: Class[GenericRecord]): GenericRecord = + deserializeDatum(input) +} diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 7cb6e080533ad..0ff7562e912ca 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,6 +27,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} +import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ @@ -73,6 +74,8 @@ class KryoSerializer(conf: SparkConf) .split(',') .filter(!_.isEmpty) + private val avroSchemas = conf.getAvroSchema + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { @@ -101,6 +104,9 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) + kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) + try { // scalastyle:off classforname // Use the default classloader when calling the user registrator. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 3bcc7178a3d8b..f038b722957b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -19,95 +19,101 @@ package org.apache.spark.shuffle import scala.collection.mutable -import org.apache.spark.{Logging, SparkException, SparkConf} +import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** - * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * Allocates a pool of memory to tasks for use in shuffle operations. 
Each disk-spilling
 * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
 * from this pool and release it as it spills data out. When a task ends, all its memory will be
 * released by the Executor.
 *
- * This class tries to ensure that each thread gets a reasonable share of memory, instead of some
- * thread ramping up to a large amount first and then causing others to spill to disk repeatedly.
- * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory
+ * This class tries to ensure that each task gets a reasonable share of memory, instead of some
+ * task ramping up to a large amount first and then causing others to spill to disk repeatedly.
+ * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
 * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
- * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever
+ * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
 * this set changes. This is all done by synchronizing access on "this" to mutate state and using
 * wait() and notifyAll() to signal changes.
 */
private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
-  private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes
+  private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes

  def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))

+  private def currentTaskAttemptId(): Long = {
+    // In case this is called on the driver, return an invalid task attempt id.
+    Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+  }
+
  /**
-   * Try to acquire up to numBytes memory for the current thread, and return the number of bytes
+   * Try to acquire up to numBytes memory for the current task, and return the number of bytes
   * obtained, or 0 if none can be allocated. This call may block until there is enough free memory
-   * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the
-   * total memory pool (where N is the # of active threads) before it is forced to spill. This can
-   * happen if the number of threads increases but an older thread had a lot of memory already.
+   * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the
+   * total memory pool (where N is the # of active tasks) before it is forced to spill. This can
+   * happen if the number of tasks increases but an older task had a lot of memory already.
*/ def tryToAcquire(numBytes: Long): Long = synchronized { - val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - // Add this thread to the threadMemory map just so we can keep an accurate count of the number - // of active threads, to let other threads ramp down their memory in calls to tryToAcquire - if (!threadMemory.contains(threadId)) { - threadMemory(threadId) = 0L - notifyAll() // Will later cause waiting threads to wake up and check numThreads again + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire + if (!taskMemory.contains(taskAttemptId)) { + taskMemory(taskAttemptId) = 0L + notifyAll() // Will later cause waiting tasks to wake up and check numThreads again } // Keep looping until we're either sure that we don't want to grant this request (because this - // thread would have more than 1 / numActiveThreads of the memory) or we have enough free - // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). while (true) { - val numActiveThreads = threadMemory.keys.size - val curMem = threadMemory(threadId) - val freeMemory = maxMemory - threadMemory.values.sum + val numActiveTasks = taskMemory.keys.size + val curMem = taskMemory(taskAttemptId) + val freeMemory = maxMemory - taskMemory.values.sum - // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads; + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem)) + val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - if (curMem < maxMemory / (2 * numActiveThreads)) { - // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; - // if we can't give it this much now, wait for other threads to free up memory - // (this happens if older threads allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + if (curMem < maxMemory / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } else { - logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + logInfo( + s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { // Only give it as much memory as is free, which might be none if it reached 1 / numThreads val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } } 0L // Never reached } - /** Release numBytes bytes for the current thread. */ + /** Release numBytes bytes for the current task. 
*/ def release(numBytes: Long): Unit = synchronized { - val threadId = Thread.currentThread().getId - val curMem = threadMemory.getOrElse(threadId, 0L) + val taskAttemptId = currentTaskAttemptId() + val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") } - threadMemory(threadId) -= numBytes + taskMemory(taskAttemptId) -= numBytes notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } - /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisThread(): Unit = synchronized { - val threadId = Thread.currentThread().getId - threadMemory.remove(threadId) + /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisTask(): Unit = synchronized { + val taskAttemptId = currentTaskAttemptId() + taskMemory.remove(taskAttemptId) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index ed609772e6979..6f27f00307f8c 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -23,6 +23,7 @@ import java.util.LinkedHashMap import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.TaskContext import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -43,11 +44,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Ensure only one thread is putting, and if necessary, dropping blocks at any given time private val accountingLock = new Object - // A mapping from thread ID to amount of memory used for unrolling a block (in bytes) + // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `accountingLock` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. - // Pending unroll memory refers to the intermediate memory occupied by a thread + // Pending unroll memory refers to the intermediate memory occupied by a task // after the unroll but before the actual putting of the block in the cache. // This chunk of memory is expected to be released *as soon as* we finish // caching the corresponding block as opposed to until after the task finishes. @@ -250,21 +251,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var elementsUnrolled = 0 // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true - // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. + // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing. 
val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory val memoryCheckPeriod = 16 - // Memory currently reserved by this thread for this particular unrolling operation + // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this thread, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisThread + // Previous unroll memory held by this task, for releasing later (only at the very end) + val previousMemoryReserved = currentUnrollMemoryForThisTask // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -283,7 +284,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { - if (!reserveUnrollMemoryForThisThread(amountToRequest)) { + if (!reserveUnrollMemoryForThisTask(amountToRequest)) { // If the first request is not granted, try again after ensuring free space // If there is still not enough space, give up and drop the partition val spaceToEnsure = maxUnrollMemory - currentUnrollMemory @@ -291,7 +292,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val result = ensureFreeSpace(blockId, spaceToEnsure) droppedBlocks ++= result.droppedBlocks } - keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest) + keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) } } // New threshold is currentSize * memoryGrowthFactor @@ -317,9 +318,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // later when the task finishes. if (keepUnrolling) { accountingLock.synchronized { - val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved - releaseUnrollMemoryForThisThread(amountToRelease) - reservePendingUnrollMemoryForThisThread(amountToRelease) + val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved + releaseUnrollMemoryForThisTask(amountToRelease) + reservePendingUnrollMemoryForThisTask(amountToRelease) } } } @@ -397,7 +398,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } // Release the unroll memory used because we no longer need the underlying Array - releasePendingUnrollMemoryForThisThread() + releasePendingUnrollMemoryForThisTask() } ResultWithDroppedBlocks(putSuccess, droppedBlocks) } @@ -427,9 +428,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Take into account the amount of memory currently occupied by unrolling blocks // and minus the pending unroll memory for that block on current thread. 
- val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() val actualFreeMemory = freeMemory - currentUnrollMemory + - pendingUnrollMemoryMap.getOrElse(threadId, 0L) + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) if (actualFreeMemory < space) { val rddToAdd = getRddId(blockIdToAdd) @@ -455,7 +456,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping + // This should never be null as only one task should be dropping // blocks and removing entries. However the check is still here for // future safety. if (entry != null) { @@ -482,79 +483,85 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) entries.synchronized { entries.containsKey(blockId) } } + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Reserve additional memory for unrolling blocks used by this thread. + * Reserve additional memory for unrolling blocks used by this task. * Return whether the request is granted. */ - def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { - val threadId = Thread.currentThread().getId - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory + val taskAttemptId = currentTaskAttemptId() + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } granted } } /** - * Release memory used by this thread for unrolling blocks. - * If the amount is not specified, remove the current thread's allocation altogether. + * Release memory used by this task for unrolling blocks. + * If the amount is not specified, remove the current task's allocation altogether. */ - def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { - val threadId = Thread.currentThread().getId + def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { if (memory < 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap.remove(taskAttemptId) } else { - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory - // If this thread claims no more unroll memory, release it completely - if (unrollMemoryMap(threadId) <= 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory + // If this task claims no more unroll memory, release it completely + if (unrollMemoryMap(taskAttemptId) <= 0) { + unrollMemoryMap.remove(taskAttemptId) } } } } /** - * Reserve the unroll memory of current unroll successful block used by this thread + * Reserve the unroll memory of current unroll successful block used by this task * until actually put the block into memory entry. 
*/ - def reservePendingUnrollMemoryForThisThread(memory: Long): Unit = { - val threadId = Thread.currentThread().getId + def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap(threadId) = pendingUnrollMemoryMap.getOrElse(threadId, 0L) + memory + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } } /** - * Release pending unroll memory of current unroll successful block used by this thread + * Release pending unroll memory of current unroll successful block used by this task */ - def releasePendingUnrollMemoryForThisThread(): Unit = { - val threadId = Thread.currentThread().getId + def releasePendingUnrollMemoryForThisTask(): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap.remove(threadId) + pendingUnrollMemoryMap.remove(taskAttemptId) } } /** - * Return the amount of memory currently occupied for unrolling blocks across all threads. + * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ def currentUnrollMemory: Long = accountingLock.synchronized { unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** - * Return the amount of memory currently occupied for unrolling blocks by this thread. + * Return the amount of memory currently occupied for unrolling blocks by this task. */ - def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { - unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) + def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { + unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** - * Return the number of threads currently unrolling blocks. + * Return the number of tasks currently unrolling blocks. */ - def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. @@ -566,7 +573,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo( s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}." 
) } diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 7d84468f62ab1..14b1f2a17e707 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -217,10 +217,10 @@ object SizeEstimator extends Logging { var arrSize: Long = alignSize(objectSize + INT_SIZE) if (elementClass.isPrimitive) { - arrSize += alignSize(length * primitiveSize(elementClass)) + arrSize += alignSize(length.toLong * primitiveSize(elementClass)) state.size += arrSize } else { - arrSize += alignSize(length * pointerSize) + arrSize += alignSize(length.toLong * pointerSize) state.size += arrSize if (length <= ARRAY_SIZE_FOR_SAMPLING) { @@ -336,7 +336,7 @@ object SizeEstimator extends Logging { // hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp var alignedSize = shellSize for (size <- fieldSizes if sizeCount(size) > 0) { - val count = sizeCount(size) + val count = sizeCount(size).toLong // If there are internal gaps, smaller field can fit in. alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count) shellSize += size * count diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 10c3eedbf4b46..04fc09b323dbb 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -111,7 +111,7 @@ public void setUp() throws IOException { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); - conf = new SparkConf(); + conf = new SparkConf().set("spark.buffer.pageSize", "128m"); taskMetrics = new TaskMetrics(); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); @@ -512,12 +512,12 @@ public void close() { } writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); writer.forceSorterToSpill(); // We should be able to write a record that's right _at_ the max record size - final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE]; + final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; new Random(42).nextBytes(atMaxRecordSize); writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); writer.forceSorterToSpill(); // Inserting a record that's larger than the max record size should fail: - final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1]; + final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; new Random(42).nextBytes(exceedsMaxRecordSize); Product2 hugeRecord = new Tuple2(new byte[0], exceedsMaxRecordSize); diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala new file mode 100644 index 0000000000000..bc9f3708ed69d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import com.esotericsoftware.kryo.io.{Output, Input} +import org.apache.avro.{SchemaBuilder, Schema} +import org.apache.avro.generic.GenericData.Record + +import org.apache.spark.{SparkFunSuite, SharedSparkContext} + +class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + val schema : Schema = SchemaBuilder + .record("testRecord").fields() + .requiredString("data") + .endRecord() + val record = new Record(schema) + record.put("data", "test data") + + test("schema compression and decompression") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema)))) + } + + test("record serialization and deserialization") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + + val outputStream = new ByteArrayOutputStream() + val output = new Output(outputStream) + genericSer.serializeDatum(record, output) + output.flush() + output.close() + + val input = new Input(new ByteArrayInputStream(outputStream.toByteArray)) + assert(genericSer.deserializeDatum(input) === record) + } + + test("uses schema fingerprint to decrease message size") { + val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema) + + val output = new Output(new ByteArrayOutputStream()) + + val beginningNormalPosition = output.total() + genericSerFull.serializeDatum(record, output) + output.flush() + val normalLength = output.total - beginningNormalPosition + + conf.registerAvroSchemas(schema) + val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema) + val beginningFingerprintPosition = output.total() + genericSerFinger.serializeDatum(record, output) + val fingerprintLength = output.total - beginningFingerprintPosition + + assert(fingerprintLength < normalLength) + } + + test("caches previously seen schemas") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + val compressedSchema = genericSer.compress(schema) + val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + + assert(compressedSchema.eq(genericSer.compress(schema))) + assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index 96778c9ebafb1..f495b6a037958 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -17,26 +17,39 @@ package org.apache.spark.shuffle +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +import 
org.mockito.Mockito._ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.CountDownLatch -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, TaskContext} class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { + + val nextTaskAttemptId = new AtomicInteger() + /** Launch a thread with the given body block and return it. */ private def startThread(name: String)(body: => Unit): Thread = { val thread = new Thread("ShuffleMemorySuite " + name) { override def run() { - body + try { + val taskAttemptId = nextTaskAttemptId.getAndIncrement + val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) + when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) + TaskContext.setTaskContext(mockTaskContext) + body + } finally { + TaskContext.unset() + } } } thread.start() thread } - test("single thread requesting memory") { + test("single task requesting memory") { val manager = new ShuffleMemoryManager(1000L) assert(manager.tryToAcquire(100L) === 100L) @@ -50,7 +63,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(manager.tryToAcquire(300L) === 300L) assert(manager.tryToAcquire(300L) === 200L) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() assert(manager.tryToAcquire(1000L) === 1000L) assert(manager.tryToAcquire(100L) === 0L) } @@ -107,8 +120,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } - test("threads cannot grow past 1 / N") { - // Two threads request 250 bytes first, wait for each other to get it, and then request + test("tasks cannot grow past 1 / N") { + // Two tasks request 250 bytes first, wait for each other to get it, and then request // 500 more; we should only grant 250 bytes to each of them on this second request val manager = new ShuffleMemoryManager(1000L) @@ -158,7 +171,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(state.t2Result2 === 250L) } - test("threads can block to get at least 1 / 2N memory") { + test("tasks can block to get at least 1 / 2N memory") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases 250 bytes, which should then be granted to t2. Further requests // by t2 will return false right away because it now has 1 / 2N of the memory. @@ -224,7 +237,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("releaseMemoryForThisThread") { + test("releaseMemoryForThisTask") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases all its memory. t2 should now be able to grab all the memory. 
@@ -251,9 +264,9 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other thread blocks for some time otherwise + // sure the other task blocks for some time otherwise Thread.sleep(300) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() } val t2 = startThread("t2") { @@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { t2.join() } - // Both threads should've been able to acquire their memory; the second one will have waited + // Both tasks should've been able to acquire their memory; the second one will have waited // until the first one acquired 1000 bytes and then released all of it state.synchronized { assert(state.t1Result === 1000L, "t1 could not allocate memory") @@ -293,7 +306,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("threads should not be granted a negative size") { + test("tasks should not be granted a negative size") { val manager = new ShuffleMemoryManager(1000L) manager.tryToAcquire(700L) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bcee901f5dd5f..f480fd107a0c2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1004,32 +1004,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(12000) val memoryStore = store.memoryStore assert(memoryStore.currentUnrollMemory === 0) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Reserve - memoryStore.reserveUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 100) - memoryStore.reserveUnrollMemoryForThisThread(200) - assert(memoryStore.currentUnrollMemoryForThisThread === 300) - memoryStore.reserveUnrollMemoryForThisThread(500) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) - memoryStore.reserveUnrollMemoryForThisThread(1000000) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted + memoryStore.reserveUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 100) + memoryStore.reserveUnrollMemoryForThisTask(200) + assert(memoryStore.currentUnrollMemoryForThisTask === 300) + memoryStore.reserveUnrollMemoryForThisTask(500) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) + memoryStore.reserveUnrollMemoryForThisTask(1000000) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted // Release - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 700) - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 600) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 700) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 600) // Reserve again - memoryStore.reserveUnrollMemoryForThisThread(4400) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) - memoryStore.reserveUnrollMemoryForThisThread(20000) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted + 
memoryStore.reserveUnrollMemoryForThisTask(4400) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) + memoryStore.reserveUnrollMemoryForThisTask(20000) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted // Release again - memoryStore.releaseUnrollMemoryForThisThread(1000) - assert(memoryStore.currentUnrollMemoryForThisThread === 4000) - memoryStore.releaseUnrollMemoryForThisThread() // release all - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + memoryStore.releaseUnrollMemoryForThisTask(1000) + assert(memoryStore.currentUnrollMemoryForThisTask === 4000) + memoryStore.releaseUnrollMemoryForThisTask() // release all + assert(memoryStore.currentUnrollMemoryForThisTask === 0) } /** @@ -1060,24 +1060,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) val memoryStore = store.memoryStore val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with all the space in the world. This should succeed and return an array. var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) - memoryStore.releasePendingUnrollMemoryForThisThread() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll with not enough space. This should succeed after kicking out someBlock1. store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) droppedBlocks.clear() - memoryStore.releasePendingUnrollMemoryForThisThread() + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. 
@@ -1085,7 +1085,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks) verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock2")) droppedBlocks.clear() @@ -1099,7 +1099,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with plenty of space. This should succeed and cache both blocks. val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) @@ -1110,7 +1110,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(result2.size > 0) assert(result1.data.isLeft) // unroll did not drop this block to disk assert(result2.data.isLeft) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Re-put these two blocks so block manager knows about them too. Otherwise, block manager // would not know how to drop them from memory later. @@ -1126,7 +1126,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b1")) assert(memoryStore.contains("b2")) assert(memoryStore.contains("b3")) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.remove("b3") store.putIterator("b3", smallIterator, memOnly) @@ -1138,7 +1138,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } /** @@ -1153,7 +1153,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) store.putIterator("b1", smallIterator, memAndDisk) store.putIterator("b2", smallIterator, memAndDisk) @@ -1170,7 +1170,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!diskStore.contains("b3")) memoryStore.remove("b3") store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll huge block with not enough space. This should fail and drop the new block to disk // directly in addition to kicking out b2 in the process. 
Memory store should contain only @@ -1186,7 +1186,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(diskStore.contains("b2")) assert(!diskStore.contains("b3")) assert(diskStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } test("multiple unrolls by the same thread") { @@ -1195,32 +1195,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // All unroll memory used is released because unrollSafely returned an array memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll memory is not released because unrollSafely returned an iterator // that still depends on the underlying vector used in the process memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB3 > 0) // The unroll memory owned by this thread builds on top of its value after the previous unrolls memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) // ... 
but only to a certain extent (until we run out of free space to grant new unroll memory) memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index dc03e374b51db..26a2e96edaaa2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -17,22 +17,29 @@ package org.apache.spark.util.collection.unsafe.sort +import com.google.common.primitives.UnsignedBytes import org.scalatest.prop.PropertyChecks - import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { test("String prefix comparator") { def testPrefixComparison(s1: String, s2: String): Unit = { - val s1Prefix = PrefixComparators.STRING.computePrefix(s1) - val s2Prefix = PrefixComparators.STRING.computePrefix(s2) + val utf8string1 = UTF8String.fromString(s1) + val utf8string2 = UTF8String.fromString(s2) + val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) + + val cmp = UnsignedBytes.lexicographicalComparator().compare( + utf8string1.getBytes.take(8), utf8string2.getBytes.take(8)) + assert( - (prefixComparisonResult == 0) || - (prefixComparisonResult < 0 && s1 < s2) || - (prefixComparisonResult > 0 && s1 > s2)) + (prefixComparisonResult == 0 && cmp == 0) || + (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) || + (prefixComparisonResult > 0 && s1.compareTo(s2) > 0)) } // scalastyle:off @@ -48,27 +55,15 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } - test("float prefix comparator handles NaN properly") { - val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) - val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) - assert(nan1.isNaN) - assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) - val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) - assert(nan1Prefix === nan2Prefix) - val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) - assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) - } - test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = 
java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) assert(nan1.isNaN) assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1) - val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2) + val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) - val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue) + val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } diff --git a/dev/run-tests.py b/dev/run-tests.py index 1f0d218514f92..29420da9aa956 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -85,6 +85,13 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe return [f for f in raw_output.split('\n') if f] +def setup_test_environ(environ): + print("[info] Setup the following environment variables for tests: ") + for (k, v) in environ.items(): + print("%s=%s" % (k, v)) + os.environ[k] = v + + def determine_modules_to_test(changed_modules): """ Given a set of modules that have changed, compute the transitive closure of those modules' @@ -455,6 +462,15 @@ def main(): print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) + # setup environment variables + # note - the 'root' module doesn't collect environment variables for all modules. Because the + # environment variables should not be set if a module is not changed, even if running the 'root' + # module. So here we should use changed_modules rather than test_modules. + test_environ = {} + for m in changed_modules: + test_environ.update(m.environ) + setup_test_environ(test_environ) + test_modules = determine_modules_to_test(changed_modules) # license checks diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 3073d489bad4a..44600cb9523c1 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -29,7 +29,7 @@ class Module(object): changed. """ - def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), should_run_r_tests=False): """ @@ -43,6 +43,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= filename strings. :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in order to build and test this module (e.g. '-PprofileName'). + :param environ: A dict of environment variables that should be set when files in this + module are changed. :param sbt_test_goals: A set of SBT test goals for testing this module. :param python_test_goals: A set of Python test goals for testing this module. 
:param blacklisted_python_implementations: A set of Python implementations that are not @@ -55,6 +57,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.source_file_prefixes = source_file_regexes self.sbt_test_goals = sbt_test_goals self.build_profile_flags = build_profile_flags + self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations self.should_run_r_tests = should_run_r_tests @@ -126,15 +129,22 @@ def contains_file(self, filename): ) +# Don't set the dependencies because changes in other modules should not trigger Kinesis tests. +# Kinesis tests depends on external Amazon kinesis service. We should run these tests only when +# files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't +# fail other PRs. streaming_kinesis_asl = Module( name="kinesis-asl", - dependencies=[streaming], + dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", ], build_profile_flags=[ "-Pkinesis-asl", ], + environ={ + "ENABLE_KINESIS_TESTS": "1" + }, sbt_test_goals=[ "kinesis-asl/test", ] @@ -313,7 +323,7 @@ def contains_file(self, filename): "pyspark.mllib.evaluation", "pyspark.mllib.feature", "pyspark.mllib.fpm", - "pyspark.mllib.linalg", + "pyspark.mllib.linalg.__init__", "pyspark.mllib.random", "pyspark.mllib.recommendation", "pyspark.mllib.regression", diff --git a/docs/configuration.md b/docs/configuration.md index 200f3cd212e46..fd236137cb96e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -203,7 +203,7 @@ Apart from these, the following properties are also available, and may be useful spark.driver.extraClassPath (none) - Extra classpath entries to append to the classpath of the driver. + Extra classpath entries to prepend to the classpath of the driver.
Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. @@ -250,7 +250,7 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to append to the classpath of executors. This exists primarily for + Extra classpath entries to prepend to the classpath of executors. This exists primarily for backwards-compatibility with older versions of Spark. Users typically should not need to set this option. diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md new file mode 100644 index 0000000000000..4ca0bb06b26a6 --- /dev/null +++ b/docs/mllib-evaluation-metrics.md @@ -0,0 +1,1497 @@ +--- +layout: global +title: Evaluation Metrics - MLlib +displayTitle: MLlib - Evaluation Metrics +--- + +* Table of contents +{:toc} + +Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance +of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +suite of metrics for the purpose of evaluating the performance of machine learning models. + +Specific machine learning algorithms fall under broader types of machine learning applications like classification, +regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those +metrics that are currently available in Spark's MLlib are detailed in this section. + +## Classification model evaluation + +While there are many different types of classification algorithms, the evaluation of classification models all share +similar principles. In a [supervised classification problem](https://en.wikipedia.org/wiki/Statistical_classification), +there exists a true output and a model-generated predicted output for each data point. For this reason, the results for +each data point can be assigned to one of four categories: + +* True Positive (TP) - label is positive and prediction is also positive +* True Negative (TN) - label is negative and prediction is also negative +* False Positive (FP) - label is negative but prediction is positive +* False Negative (FN) - label is positive but prediction is negative + +These four numbers are the building blocks for most classifier evaluation metrics. A fundamental point when considering +classifier evaluation is that pure accuracy (i.e. was the prediction correct or incorrect) is not generally a good metric. The +reason for this is because a dataset may be highly unbalanced. For example, if a model is designed to predict fraud from +a dataset where 95% of the data points are _not fraud_ and 5% of the data points are _fraud_, then a naive classifier +that predicts _not fraud_, regardless of input, will be 95% accurate. For this reason, metrics like +[precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) are typically used because they take into +account the *type* of error. In most applications there is some desired balance between precision and recall, which can +be captured by combining the two into a single metric, called the [F-measure](https://en.wikipedia.org/wiki/F1_score). 
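As a quick worked illustration of why raw accuracy can be misleading (using assumed numbers: 1,000 data points of which 50 are fraud, and a classifier that always predicts _not fraud_):

$$\text{accuracy} = \frac{TP + TN}{TP + TN + FP + FN} = \frac{0 + 950}{1000} = 0.95, \qquad \text{recall} = \frac{TP}{TP + FN} = \frac{0}{0 + 50} = 0$$

The classifier looks 95% accurate, yet it never flags a single fraudulent point (precision is not even defined, since it makes no positive predictions). This is exactly the kind of failure that precision, recall, and the F-measure are designed to expose.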
+
+### Binary classification
+
+[Binary classifiers](https://en.wikipedia.org/wiki/Binary_classification) are used to separate the elements of a given dataset into one of two possible groups (e.g. fraud or not fraud); binary classification is a special case of multiclass classification. Most binary classification metrics can be generalized to multiclass classification metrics.
+
+#### Threshold tuning
+
+It is important to understand that many classification models actually output a "score" (often a probability) for each class, where a higher score indicates higher likelihood. In the binary case, the model may output a probability for each class: $P(Y=1|X)$ and $P(Y=0|X)$. Instead of simply taking the higher probability, there may be some cases where the model might need to be tuned so that it only predicts a class when the probability is very high (e.g. only block a credit card transaction if the model predicts fraud with >90% probability). Therefore, there is a prediction *threshold* which determines what the predicted class will be based on the probabilities that the model outputs.
+
+Tuning the prediction threshold will change the precision and recall of the model and is an important part of model optimization. In order to visualize how precision, recall, and other metrics change as a function of the threshold, it is common practice to plot competing metrics against one another, parameterized by threshold. A P-R curve plots (precision, recall) points for different threshold values, while a [receiver operating characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic), or ROC, curve plots (recall, false positive rate) points.
+
+**Available metrics**
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision (Positive Predictive Value)</td>
+      <td>$PPV=\frac{TP}{TP + FP}$</td>
+    </tr>
+    <tr>
+      <td>Recall (True Positive Rate)</td>
+      <td>$TPR=\frac{TP}{P}=\frac{TP}{TP + FN}$</td>
+    </tr>
+    <tr>
+      <td>F-measure</td>
+      <td>$F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR}{\beta^2 \cdot PPV + TPR}\right)$</td>
+    </tr>
+    <tr>
+      <td>Receiver Operating Characteristic (ROC)</td>
+      <td>$FPR(T)=\int^\infty_{T} P_0(T)\,dT \\ TPR(T)=\int^\infty_{T} P_1(T)\,dT$</td>
+    </tr>
+    <tr>
+      <td>Area Under ROC Curve</td>
+      <td>$AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)$</td>
+    </tr>
+    <tr>
+      <td>Area Under Precision-Recall Curve</td>
+      <td>$AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)$</td>
+    </tr>
+  </tbody>
+</table>
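To connect the threshold-tuning discussion above with the metrics in this table, the sketch below is illustrative only: it assumes the `model` and `metrics` values constructed in the Scala example further down this section, with the model's threshold cleared so that it outputs raw scores. It picks the threshold that maximizes the F-measure and applies it back to the model.

{% highlight scala %}
// Illustrative sketch: fMeasureByThreshold returns (threshold, F-measure) pairs,
// so reduce to the pair with the highest F-measure and re-apply that threshold.
val (bestThreshold, bestF1) = metrics.fMeasureByThreshold
  .reduce((a, b) => if (a._2 > b._2) a else b)
println(s"Best threshold = $bestThreshold (F1 = $bestF1)")
model.setThreshold(bestThreshold)  // predict() now returns 0.0/1.0 labels using this cutoff
{% endhighlight %}

Whether the F-measure is the right selection criterion depends on the application; the same pattern works with `precisionByThreshold` or `recallByThreshold` when a different operating point is preferred.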
+ + +**Examples** + +
+The following code snippets illustrate how to load a sample dataset, train a binary classification algorithm on the +data, and evaluate the performance of the algorithm by several binary evaluation metrics. + +
+
+{% highlight scala %}
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+
+// Load training data in LIBSVM format
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
+
+// Split data into training (60%) and test (40%)
+val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+training.cache()
+
+// Run training algorithm to build the model
+val model = new LogisticRegressionWithLBFGS()
+  .setNumClasses(2)
+  .run(training)
+
+// Clear the prediction threshold so the model will return probabilities
+model.clearThreshold
+
+// Compute raw scores on the test set
+val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
+  val prediction = model.predict(features)
+  (prediction, label)
+}
+
+// Instantiate metrics object
+val metrics = new BinaryClassificationMetrics(predictionAndLabels)
+
+// Precision by threshold
+val precision = metrics.precisionByThreshold
+precision.foreach { case (t, p) =>
+  println(s"Threshold: $t, Precision: $p")
+}
+
+// Recall by threshold
+val recall = metrics.recallByThreshold
+recall.foreach { case (t, r) =>
+  println(s"Threshold: $t, Recall: $r")
+}
+
+// Precision-Recall Curve
+val PRC = metrics.pr
+
+// F-measure
+val f1Score = metrics.fMeasureByThreshold
+f1Score.foreach { case (t, f) =>
+  println(s"Threshold: $t, F-score: $f, Beta = 1")
+}
+
+val beta = 0.5
+val fScore = metrics.fMeasureByThreshold(beta)
+fScore.foreach { case (t, f) =>
+  println(s"Threshold: $t, F-score: $f, Beta = 0.5")
+}
+
+// AUPRC
+val auPRC = metrics.areaUnderPR
+println("Area under precision-recall curve = " + auPRC)
+
+// Compute thresholds used in ROC and PR curves
+val thresholds = precision.map(_._1)
+
+// ROC Curve
+val roc = metrics.roc
+
+// AUROC
+val auROC = metrics.areaUnderROC
+println("Area under ROC = " + auROC)
+
+{% endhighlight %}
+
+ +
+ +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class BinaryClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call (Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
+ +
+ +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Several of the methods available in scala are currently missing from pyspark + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = BinaryClassificationMetrics(predictionAndLabels) + +# Area under precision-recall curve +print "Area under PR = %s" % metrics.areaUnderPR + +# Area under ROC curve +print "Area under ROC = %s" % metrics.areaUnderROC + +{% endhighlight %} + +
+
+
+
+### Multiclass classification
+
+A [multiclass classification](https://en.wikipedia.org/wiki/Multiclass_classification) problem is one where there are $M \gt 2$ possible labels for each data point (the case where $M=2$ is the binary classification problem). For example, classifying handwriting samples as one of the digits 0 to 9 gives 10 possible classes.
+
+For multiclass metrics, the notion of positives and negatives is slightly different. Predictions and labels can still be positive or negative, but they must be considered under the context of a particular class. Each label and prediction take on the value of one of the multiple classes and so they are said to be positive for their particular class and negative for all other classes. So, a true positive occurs whenever the prediction and the label match, while a true negative occurs when neither the prediction nor the label take on the value of a given class. By this convention, there can be multiple true negatives for a given data sample. The extension of false negatives and false positives from the former definitions of positive and negative labels is straightforward.
+
+#### Label-based metrics
+
+As opposed to binary classification, where there are only two possible labels, multiclass classification problems have many possible labels, so the concept of label-based metrics is introduced. Overall precision measures precision across all labels - the number of times any class was predicted correctly (true positives) normalized by the number of data points. Precision by label considers only one class, and measures the number of times a specific label was predicted correctly normalized by the number of times that label appears in the output.
+
+**Available metrics**
+
+Define the class, or label, set as
+
+$$L = \{\ell_0, \ell_1, \ldots, \ell_{M-1} \} $$
+
+The true output vector $\mathbf{y}$ consists of $N$ elements
+
+$$\mathbf{y}_0, \mathbf{y}_1, \ldots, \mathbf{y}_{N-1} \in L $$
+
+A multiclass prediction algorithm generates a prediction vector $\hat{\mathbf{y}}$ of $N$ elements
+
+$$\hat{\mathbf{y}}_0, \hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_{N-1} \in L $$
+
+For this section, a modified delta function $\hat{\delta}(x)$ will prove useful
+
+$$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Confusion Matrix</td>
+      <td>
+        $C_{ij} = \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_i) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_j)\\ \\
+        \left( \begin{array}{ccc}
+        \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_1) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_1) & \ldots & \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_1) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_N) \\
+        \vdots & \ddots & \vdots \\
+        \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_N) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_1) & \ldots & \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_N) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_N)
+        \end{array} \right)$
+      </td>
+    </tr>
+    <tr>
+      <td>Overall Precision</td>
+      <td>$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$</td>
+    </tr>
+    <tr>
+      <td>Overall Recall</td>
+      <td>$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$</td>
+    </tr>
+    <tr>
+      <td>Overall F1-measure</td>
+      <td>$F1 = 2 \cdot \left(\frac{PPV \cdot TPR}{PPV + TPR}\right)$</td>
+    </tr>
+    <tr>
+      <td>Precision by label</td>
+      <td>$PPV(\ell) = \frac{TP}{TP + FP} = \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell)}$</td>
+    </tr>
+    <tr>
+      <td>Recall by label</td>
+      <td>$TPR(\ell)=\frac{TP}{P} = \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}{\sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i - \ell)}$</td>
+    </tr>
+    <tr>
+      <td>F-measure by label</td>
+      <td>$F(\beta, \ell) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}{\beta^2 \cdot PPV(\ell) + TPR(\ell)}\right)$</td>
+    </tr>
+    <tr>
+      <td>Weighted precision</td>
+      <td>$PPV_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} PPV(\ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+    <tr>
+      <td>Weighted recall</td>
+      <td>$TPR_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} TPR(\ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+    <tr>
+      <td>Weighted F-measure</td>
+      <td>$F_{w}(\beta)= \frac{1}{N} \sum\nolimits_{\ell \in L} F(\beta, \ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+  </tbody>
+</table>
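As a small worked illustration of these definitions, assume a toy dataset with $N = 3$ points and two classes, where only the second point is misclassified:

$$\mathbf{y} = (\ell_0, \ell_1, \ell_1), \quad \hat{\mathbf{y}} = (\ell_0, \ell_0, \ell_1) \quad\Rightarrow\quad PPV = TPR = \frac{2}{3}, \quad PPV(\ell_1) = \frac{1}{1} = 1, \quad TPR(\ell_1) = \frac{1}{2}$$

Two of the three predictions match their labels (overall precision $2/3$), every prediction of $\ell_1$ is correct (precision by label $1$), but only one of the two true $\ell_1$ points is recovered (recall by label $1/2$).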
+ +**Examples** + +
+The following code snippets illustrate how to load a sample dataset, train a multiclass classification algorithm on +the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics. + +
+ +{% highlight scala %} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +// Split data into training (60%) and test (40%) +val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) +training.cache() + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + +// Compute raw scores on the test set +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Instantiate metrics object +val metrics = new MulticlassMetrics(predictionAndLabels) + +// Confusion matrix +println("Confusion matrix:") +println(metrics.confusionMatrix) + +// Overall Statistics +val precision = metrics.precision +val recall = metrics.recall // same as true positive rate +val f1Score = metrics.fMeasure +println("Summary Statistics") +println(s"Precision = $precision") +println(s"Recall = $recall") +println(s"F1 Score = $f1Score") + +// Precision by label +val labels = metrics.labels +labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) +} + +// Recall by label +labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) +} + +// False positive rate by label +labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) +} + +// F-measure by label +labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) +} + +// Weighted stats +println(s"Weighted precision: ${metrics.weightedPrecision}") +println(s"Weighted recall: ${metrics.weightedRecall}") +println(s"Weighted F1 score: ${metrics.weightedFMeasure}") +println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + +{% endhighlight %} + +
+ +
+ +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class MulticlassClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
+ +
+ +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = MulticlassMetrics(predictionAndLabels) + +# Overall statistics +precision = metrics.precision() +recall = metrics.recall() +f1Score = metrics.fMeasure() +print "Summary Stats" +print "Precision = %s" % precision +print "Recall = %s" % recall +print "F1 Score = %s" % f1Score + +# Statistics by class +labels = data.map(lambda lp: lp.label).distinct().collect() +for label in sorted(labels): + print "Class %s precision = %s" % (label, metrics.precision(label)) + print "Class %s recall = %s" % (label, metrics.recall(label)) + print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)) + +# Weighted stats +print "Weighted recall = %s" % metrics.weightedRecall +print "Weighted precision = %s" % metrics.weightedPrecision +print "Weighted F(1) Score = %s" % metrics.weightedFMeasure() +print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5) +print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate +{% endhighlight %} + +
+
+
+### Multilabel classification
+
+A [multilabel classification](https://en.wikipedia.org/wiki/Multi-label_classification) problem involves mapping
+each sample in a dataset to a set of class labels. In this type of classification problem, the labels are not
+mutually exclusive. For example, when classifying a set of news articles into topics, a single article might be both
+science and politics.
+
+Because the labels are not mutually exclusive, the predictions and true labels are now vectors of label *sets*, rather
+than vectors of labels. Multilabel metrics, therefore, extend the fundamental ideas of precision, recall, etc. to
+operations on sets. For example, a true positive for a given class now occurs when, for a specific data point, that
+class appears both in the predicted set and in the true label set.
+
+**Available metrics**
+
+Here we define a set $D$ of $N$ documents
+
+$$D = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+
+Define $L_0, L_1, ..., L_{N-1}$ to be a family of label sets and $P_0, P_1, ..., P_{N-1}$
+to be a family of prediction sets where $L_i$ and $P_i$ are the label set and prediction set, respectively, that
+correspond to document $d_i$.
+
+The set of all unique labels is given by
+
+$$L = \bigcup_{k=0}^{N-1} L_k$$
+
+The following definition of the indicator function $I_A(x)$ on a set $A$ will be necessary
+
+$$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|P_i \cap L_i\right|}{\left|P_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Recall</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|L_i \cap P_i\right|}{\left|L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Accuracy</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N - 1} \frac{\left|L_i \cap P_i \right|}{\left|L_i\right| + \left|P_i\right| - \left|L_i \cap P_i \right|}$</td>
+    </tr>
+    <tr>
+      <td>Precision by label</td>
+      <td>$PPV(\ell)=\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}{\sum_{i=0}^{N-1} I_{P_i}(\ell)}$</td>
+    </tr>
+    <tr>
+      <td>Recall by label</td>
+      <td>$TPR(\ell)=\frac{TP}{P}=\frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}{\sum_{i=0}^{N-1} I_{L_i}(\ell)}$</td>
+    </tr>
+    <tr>
+      <td>F1-measure by label</td>
+      <td>$F1(\ell) = 2 \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}{PPV(\ell) + TPR(\ell)}\right)$</td>
+    </tr>
+    <tr>
+      <td>Hamming Loss</td>
+      <td>$\frac{1}{N \cdot \left|L\right|} \sum_{i=0}^{N - 1} \left(\left|L_i\right| + \left|P_i\right| - 2\left|L_i \cap P_i\right|\right)$</td>
+    </tr>
+    <tr>
+      <td>Subset Accuracy</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} I_{\{L_i\}}(P_i)$</td>
+    </tr>
+    <tr>
+      <td>F1 Measure</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{2 \left|P_i \cap L_i\right|}{\left|P_i\right| + \left|L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro precision</td>
+      <td>$\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|P_i - L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro recall</td>
+      <td>$\frac{TP}{TP + FN}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro F1 Measure</td>
+      <td>$2 \cdot \frac{TP}{2 \cdot TP + FP + FN}=2 \cdot \frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{2 \cdot \sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right| + \sum_{i=0}^{N-1} \left|P_i - L_i\right|}$</td>
+    </tr>
+  </tbody>
+</table>
+
+**Examples**
+
+The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples
+use the toy prediction and label data shown below; a few of the metrics are also worked out by hand immediately
+after these lists.
+
+Document predictions:
+
+* doc 0 - predict 0, 1 - class 0, 2
+* doc 1 - predict 0, 2 - class 0, 1
+* doc 2 - predict none - class 0
+* doc 3 - predict 2 - class 2
+* doc 4 - predict 2, 0 - class 2, 0
+* doc 5 - predict 0, 1, 2 - class 0, 1
+* doc 6 - predict 1 - class 1, 2
+
+Predicted classes:
+
+* class 0 - doc 0, 1, 4, 5 (total 4)
+* class 1 - doc 0, 5, 6 (total 3)
+* class 2 - doc 1, 3, 4, 5 (total 4)
+
+True classes:
+
+* class 0 - doc 0, 1, 2, 4, 5 (total 5)
+* class 1 - doc 1, 5, 6 (total 3)
+* class 2 - doc 0, 3, 4, 6 (total 4)
+
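+Worked out by hand from the definitions above (treating a document with an empty prediction set as contributing
+$0$ to the precision term), this toy data gives
+
+$$\text{precision} = \frac{1}{7}\left(\frac{1}{2} + \frac{1}{2} + 0 + 1 + 1 + \frac{2}{3} + 1\right) \approx 0.667,
+\qquad \text{recall} = \frac{1}{7}\left(\frac{1}{2} + \frac{1}{2} + 0 + 1 + 1 + 1 + \frac{1}{2}\right) \approx 0.643,$$
+
+$$\text{subset accuracy} = \frac{2}{7} \approx 0.286,$$
+
+since only documents 3 and 4 are predicted exactly. The code below computes the same quantities.
+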
+ +
+ +{% highlight scala %} +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD; + +val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array(), Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + +// Instantiate metrics object +val metrics = new MultilabelMetrics(scoreAndLabels) + +// Summary stats +println(s"Recall = ${metrics.recall}") +println(s"Precision = ${metrics.precision}") +println(s"F1 measure = ${metrics.f1Measure}") +println(s"Accuracy = ${metrics.accuracy}") + +// Individual label stats +metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) +metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) +metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + +// Micro stats +println(s"Micro recall = ${metrics.microRecall}") +println(s"Micro precision = ${metrics.microPrecision}") +println(s"Micro F1 measure = ${metrics.microF1Measure}") + +// Hamming loss +println(s"Hamming loss = ${metrics.hammingLoss}") + +// Subset accuracy +println(s"Subset accuracy = ${metrics.subsetAccuracy}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.evaluation.MultilabelMetrics;
+import org.apache.spark.SparkConf;
+import java.util.Arrays;
+import java.util.List;
+
+public class MultilabelClassification {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    List<Tuple2<double[], double[]>> data = Arrays.asList(
+      new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
+      new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
+      new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}),
+      new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}),
+      new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
+      new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
+      new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0})
+    );
+    JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);
+
+    // Instantiate metrics object
+    MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());
+
+    // Summary stats
+    System.out.format("Recall = %f\n", metrics.recall());
+    System.out.format("Precision = %f\n", metrics.precision());
+    System.out.format("F1 measure = %f\n", metrics.f1Measure());
+    System.out.format("Accuracy = %f\n", metrics.accuracy());
+
+    // Stats by labels
+    for (int i = 0; i < metrics.labels().length; i++) {
+      System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
+      System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
+      System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i]));
+    }
+
+    // Micro stats
+    System.out.format("Micro recall = %f\n", metrics.microRecall());
+    System.out.format("Micro precision = %f\n", metrics.microPrecision());
+    System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure());
+
+    // Hamming loss
+    System.out.format("Hamming loss = %f\n", metrics.hammingLoss());
+
+    // Subset accuracy
+    System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy());
+  }
+}
+
+{% endhighlight %}
+
+ +
+ +{% highlight python %} +from pyspark.mllib.evaluation import MultilabelMetrics + +scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + +# Instantiate metrics object +metrics = MultilabelMetrics(scoreAndLabels) + +# Summary stats +print "Recall = %s" % metrics.recall() +print "Precision = %s" % metrics.precision() +print "F1 measure = %s" % metrics.f1Measure() +print "Accuracy = %s" % metrics.accuracy + +# Individual label stats +labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() +for label in labels: + print "Class %s precision = %s" % (label, metrics.precision(label)) + print "Class %s recall = %s" % (label, metrics.recall(label)) + print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label)) + +# Micro stats +print "Micro precision = %s" % metrics.microPrecision +print "Micro recall = %s" % metrics.microRecall +print "Micro F1 measure = %s" % metrics.microF1Measure + +# Hamming loss +print "Hamming loss = %s" % metrics.hammingLoss + +# Subset accuracy +print "Subset accuracy = %s" % metrics.subsetAccuracy + +{% endhighlight %} + +
+
+
+### Ranking systems
+
+The role of a ranking algorithm (often thought of as a [recommender system](https://en.wikipedia.org/wiki/Recommender_system))
+is to return to the user a set of relevant items or documents based on some training data. The definition of relevance
+may vary and is usually application specific. Ranking system metrics aim to quantify the effectiveness of these
+rankings or recommendations in various contexts. Some metrics compare a set of recommended documents to a ground truth
+set of relevant documents, while other metrics may incorporate numerical ratings explicitly.
+
+**Available metrics**
+
+A ranking system usually deals with a set of $M$ users
+
+$$U = \left\{u_0, u_1, ..., u_{M-1}\right\}$$
+
+Each user ($u_i$) has a set of $N$ ground truth relevant documents
+
+$$D_i = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+
+and a list of $Q$ recommended documents, in order of decreasing relevance
+
+$$R_i = \left[r_0, r_1, ..., r_{Q-1}\right]$$
+
+The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the
+sets and the effectiveness of the algorithms can be measured using the metrics listed below.
+
+It is necessary to define a function which, provided a recommended document and a set of ground truth relevant
+documents, returns a relevance score for the recommended document.
+
+$$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th><th>Notes</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision at k</td>
+      <td>
+        $p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} rel_{D_i}(R_i(j))}$
+      </td>
+      <td>
+        Precision at k is a measure of how many of the first k recommended documents are in the set of true relevant
+        documents, averaged across all users. In this metric, the order of the recommendations is not taken into account.
+      </td>
+    </tr>
+    <tr>
+      <td>Mean Average Precision</td>
+      <td>
+        $MAP=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{\left|D_i\right|} \sum_{j=0}^{Q-1} \frac{rel_{D_i}(R_i(j))}{j + 1}}$
+      </td>
+      <td>
+        MAP is a measure of how many of the recommended documents are in the set of true relevant documents, where the
+        order of the recommendations is taken into account (i.e. mistakes near the top of the ranking are penalized
+        more heavily than mistakes near the bottom).
+      </td>
+    </tr>
+    <tr>
+      <td>Normalized Discounted Cumulative Gain</td>
+      <td>
+        $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1}
+          \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\
+        \text{Where} \\
+        \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\
+        \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$
+      </td>
+      <td>
+        NDCG at k is a measure of how many of the first k recommended documents are in the set of true relevant
+        documents, averaged across all users. In contrast to precision at k, this metric takes into account the order
+        of the recommendations (documents are assumed to be in order of decreasing relevance).
+      </td>
+    </tr>
+  </tbody>
+</table>
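+As a small hand-worked illustration of the notation (not Spark output): for a single user with ground truth set
+$D_0 = \{a, b, c\}$ and recommendation list $R_0 = [a, d, b]$, the relevance scores $rel_{D_0}(R_0(j))$ are
+$1, 0, 1$, so precision at 1 is $1$, precision at 2 is $\frac{1}{2}$, and precision at 3 is $\frac{2}{3}$.
+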
+
+**Examples**
+
+The following code snippets illustrate how to load a sample dataset, train an alternating least squares recommendation
+model on the data, and evaluate the performance of the recommender by several ranking metrics. A brief summary of the
+methodology is provided below.
+
+MovieLens ratings are on a scale of 1-5:
+
+ * 5: Must see
+ * 4: Will enjoy
+ * 3: It's okay
+ * 2: Fairly bad
+ * 1: Awful
+
+So we should not recommend a movie if the predicted rating is less than 3.
+To map ratings to confidence scores, we use:
+
+ * 5 -> 2.5
+ * 4 -> 1.5
+ * 3 -> 0.5
+ * 2 -> -0.5
+ * 1 -> -1.5.
+
+This mapping means unobserved entries are generally between "It's okay" and "Fairly bad". The semantics of 0 in this
+expanded world of non-positive weights are "the same as never having interacted at all."
+
+ +
+ +{% highlight scala %} +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} + +// Read in the ratings data +val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) +}.cache() + +// Map ratings to 1 or 0, 1 indicating a movie that should be recommended +val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() + +// Summarize ratings +val numRatings = ratings.count() +val numUsers = ratings.map(_.user).distinct().count() +val numMovies = ratings.map(_.product).distinct().count() +println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + +// Build the model +val numIterations = 10 +val rank = 10 +val lambda = 0.01 +val model = ALS.train(ratings, rank, numIterations, lambda) + +// Define a function to scale ratings from 0 to 1 +def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) +} + +// Get sorted top ten predictions for each user and then scale from [0, 1] +val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => + (user, recs.map(scaledRating)) +} + +// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document +// Compare with top ten most relevant documents +val userMovies = binarizedRatings.groupBy(_.user) +val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) +} + +// Instantiate metrics object +val metrics = new RankingMetrics(relevantDocuments) + +// Precision at K +Array(1, 3, 5).foreach{ k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") +} + +// Mean average precision +println(s"Mean average precision = ${metrics.meanAveragePrecision}") + +// Normalized discounted cumulative gain +Array(1, 3, 5).foreach{ k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") +} + +// Get predictions for each data point +val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) +val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) +val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => + (predicted, actual) +} + +// Get the RMSE using regression metrics +val regressionMetrics = new RegressionMetrics(predictionsAndLabels) +println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${regressionMetrics.r2}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.Function;
+import java.util.*;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.mllib.evaluation.RankingMetrics;
+import org.apache.spark.mllib.recommendation.ALS;
+import org.apache.spark.mllib.recommendation.Rating;
+
+public class Ranking {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Ranking Metrics");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Read in the ratings data
+    String path = "data/mllib/sample_movielens_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<Rating> ratings = data.map(
+      new Function<String, Rating>() {
+        public Rating call(String line) {
+          String[] parts = line.split("::");
+          return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5);
+        }
+      }
+    );
+    ratings.cache();
+
+    // Train an ALS model
+    final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01);
+
+    // Get top 10 recommendations for every user and scale ratings from 0 to 1
+    JavaRDD<Tuple2<Object, Rating[]>> userRecs = model.recommendProductsForUsers(10).toJavaRDD();
+    JavaRDD<Tuple2<Object, Rating[]>> userRecsScaled = userRecs.map(
+      new Function<Tuple2<Object, Rating[]>, Tuple2<Object, Rating[]>>() {
+        public Tuple2<Object, Rating[]> call(Tuple2<Object, Rating[]> t) {
+          Rating[] scaledRatings = new Rating[t._2().length];
+          for (int i = 0; i < scaledRatings.length; i++) {
+            double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0);
+            scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating);
+          }
+          return new Tuple2<Object, Rating[]>(t._1(), scaledRatings);
+        }
+      }
+    );
+    JavaPairRDD<Object, Rating[]> userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled);
+
+    // Map ratings to 1 or 0, 1 indicating a movie that should be recommended
+    JavaRDD<Rating> binarizedRatings = ratings.map(
+      new Function<Rating, Rating>() {
+        public Rating call(Rating r) {
+          double binaryRating;
+          if (r.rating() > 0.0) {
+            binaryRating = 1.0;
+          } else {
+            binaryRating = 0.0;
+          }
+          return new Rating(r.user(), r.product(), binaryRating);
+        }
+      }
+    );
+
+    // Group ratings by common user
+    JavaPairRDD<Object, Iterable<Rating>> userMovies = binarizedRatings.groupBy(
+      new Function<Rating, Object>() {
+        public Object call(Rating r) {
+          return r.user();
+        }
+      }
+    );
+
+    // Get true relevant documents from all user ratings
+    JavaPairRDD<Object, List<Integer>> userMoviesList = userMovies.mapValues(
+      new Function<Iterable<Rating>, List<Integer>>() {
+        public List<Integer> call(Iterable<Rating> docs) {
+          List<Integer> products = new ArrayList<Integer>();
+          for (Rating r : docs) {
+            if (r.rating() > 0.0) {
+              products.add(r.product());
+            }
+          }
+          return products;
+        }
+      }
+    );
+
+    // Extract the product id from each recommendation
+    JavaPairRDD<Object, List<Integer>> userRecommendedList = userRecommended.mapValues(
+      new Function<Rating[], List<Integer>>() {
+        public List<Integer> call(Rating[] docs) {
+          List<Integer> products = new ArrayList<Integer>();
+          for (Rating r : docs) {
+            products.add(r.product());
+          }
+          return products;
+        }
+      }
+    );
+    JavaRDD<Tuple2<List<Integer>, List<Integer>>> relevantDocs = userMoviesList.join(userRecommendedList).values();
+
+    // Instantiate the metrics object
+    RankingMetrics<Integer> metrics = RankingMetrics.of(relevantDocs);
+
+    // Precision and NDCG at k
+    Integer[] kVector = {1, 3, 5};
+    for (Integer k : kVector) {
+      System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
+      System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
+    }
+
+    // Mean average precision
+    System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision());
+
+    // Evaluate the model using numerical ratings and regression metrics
+    JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(
+      new Function<Rating, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(Rating r) {
+          return new Tuple2<Object, Object>(r.user(), r.product());
+        }
+      }
+    );
+    JavaPairRDD<Tuple2<Integer, Integer>, Object> predictions = JavaPairRDD.fromJavaRDD(
+      model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r) {
+            return new Tuple2<Tuple2<Integer, Integer>, Object>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      ));
+    JavaRDD<Tuple2<Object, Object>> ratesAndPreds =
+      JavaPairRDD.fromJavaRDD(ratings.map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r) {
+            return new Tuple2<Tuple2<Integer, Integer>, Object>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      )).join(predictions).values();
+
+    // Create regression metrics object
+    RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd());
+
+    // Root mean squared error
+    System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError());
+
+    // R-squared
+    System.out.format("R-squared = %f\n", regressionMetrics.r2());
+  }
+}
+
+{% endhighlight %}
+
+ +
+
+{% highlight python %}
+from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics
+
+# Read in the ratings data
+lines = sc.textFile("data/mllib/sample_movielens_data.txt")
+
+def parseLine(line):
+    fields = line.split("::")
+    return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5)
+ratings = lines.map(lambda r: parseLine(r))
+
+# Train a model to predict user-product ratings
+model = ALS.train(ratings, 10, 10, 0.01)
+
+# Get predicted ratings on all existing user-product pairs
+testData = ratings.map(lambda p: (p.user, p.product))
+predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating))
+
+ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating))
+scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
+
+# Instantiate regression metrics to compare predicted and actual ratings
+metrics = RegressionMetrics(scoreAndLabels)
+
+# Root mean squared error
+print "RMSE = %s" % metrics.rootMeanSquaredError
+
+# R-squared
+print "R-squared = %s" % metrics.r2
+
+{% endhighlight %}
+
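+The Python example above only exercises `RegressionMetrics` on the numerical ratings. As a minimal sketch of the
+ranking metrics themselves (the per-user recommendation and relevant-document lists below are made up for
+illustration, not derived from the MovieLens data), `RankingMetrics` can be used directly once such lists are
+available; it reuses the `RankingMetrics` import from the snippet above:
+
+{% highlight python %}
+# Each element pairs one user's recommendation list with that user's relevant documents.
+predictionAndLabels = sc.parallelize([
+    ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
+    ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3])])
+rankingMetrics = RankingMetrics(predictionAndLabels)
+
+# Precision at K
+print "Precision at 5 = %s" % rankingMetrics.precisionAt(5)
+
+# Mean average precision
+print "Mean average precision = %s" % rankingMetrics.meanAveragePrecision
+
+# Normalized discounted cumulative gain
+print "NDCG at 5 = %s" % rankingMetrics.ndcgAt(5)
+{% endhighlight %}
+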
+
+ +## Regression model evaluation + +[Regression analysis](https://en.wikipedia.org/wiki/Regression_analysis) is used when predicting a continuous output +variable from a number of independent variables. + +**Available metrics** + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Mean Squared Error (MSE)</td>
+      <td>$MSE = \frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}$</td>
+    </tr>
+    <tr>
+      <td>Root Mean Squared Error (RMSE)</td>
+      <td>$RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$</td>
+    </tr>
+    <tr>
+      <td>Mean Absolute Error (MAE)</td>
+      <td>$MAE = \frac{1}{N} \sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$</td>
+    </tr>
+    <tr>
+      <td>Coefficient of Determination $(R^2)$</td>
+      <td>$R^2 = 1 - \frac{MSE}{\text{VAR}(\mathbf{y})} = 1-\frac{\sum_{i=0}^{N-1}
+        (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{\sum_{i=0}^{N-1}(\mathbf{y}_i-\bar{\mathbf{y}})^2}$</td>
+    </tr>
+    <tr>
+      <td>Explained Variance</td>
+      <td>$1 - \frac{\text{VAR}(\mathbf{y} - \mathbf{\hat{y}})}{\text{VAR}(\mathbf{y})}$</td>
+    </tr>
+  </tbody>
+</table>
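+As a quick hand-worked sanity check on these definitions (made-up numbers, not output from any Spark example),
+suppose $\mathbf{y} = (2.5, 0.0, 2.0, 8.0)$ and $\hat{\mathbf{y}} = (3.0, -0.5, 2.0, 7.0)$. The squared errors are
+$(0.25, 0.25, 0, 1)$, so
+
+$$MSE = \frac{1.5}{4} = 0.375, \qquad RMSE = \sqrt{0.375} \approx 0.612, \qquad MAE = \frac{2}{4} = 0.5,$$
+
+and with $\bar{\mathbf{y}} = 3.125$ the total sum of squares is $35.1875$, giving
+$R^2 = 1 - \frac{1.5}{35.1875} \approx 0.957$.
+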
+ +**Examples** + +
+The following code snippets illustrate how to load a sample dataset, train a linear regression algorithm on the data, +and evaluate the performance of the algorithm by several regression metrics. + +
+ +{% highlight scala %} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils + +// Load the data +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + +// Build the model +val numIterations = 100 +val model = LinearRegressionWithSGD.train(data, numIterations) + +// Get predictions +val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) +} + +// Instantiate metrics object +val metrics = new RegressionMetrics(valuesAndPreds) + +// Squared error +println(s"MSE = ${metrics.meanSquaredError}") +println(s"RMSE = ${metrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${metrics.r2}") + +// Mean absolute error +println(s"MAE = ${metrics.meanAbsoluteError}") + +// Explained variance +println(s"Explained variance = ${metrics.explainedVariance}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.regression.LinearRegressionModel;
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.SparkConf;
+
+public class LinearRegression {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Load and parse the data
+    String path = "data/mllib/sample_linear_regression_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<LabeledPoint> parsedData = data.map(
+      new Function<String, LabeledPoint>() {
+        public LabeledPoint call(String line) {
+          String[] parts = line.split(" ");
+          double[] v = new double[parts.length - 1];
+          for (int i = 1; i < parts.length; i++)
+            v[i - 1] = Double.parseDouble(parts[i].split(":")[1]);
+          return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
+        }
+      }
+    );
+    parsedData.cache();
+
+    // Building the model
+    int numIterations = 100;
+    final LinearRegressionModel model =
+      LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);
+
+    // Evaluate model on training examples and compute training error
+    JavaRDD<Tuple2<Object, Object>> valuesAndPreds = parsedData.map(
+      new Function<LabeledPoint, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(LabeledPoint point) {
+          double prediction = model.predict(point.features());
+          return new Tuple2<Object, Object>(prediction, point.label());
+        }
+      }
+    );
+
+    // Instantiate metrics object
+    RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd());
+
+    // Squared error
+    System.out.format("MSE = %f\n", metrics.meanSquaredError());
+    System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError());
+
+    // R-squared
+    System.out.format("R Squared = %f\n", metrics.r2());
+
+    // Mean absolute error
+    System.out.format("MAE = %f\n", metrics.meanAbsoluteError());
+
+    // Explained variance
+    System.out.format("Explained Variance = %f\n", metrics.explainedVariance());
+
+    // Save and load model
+    model.save(sc.sc(), "myModelPath");
+    LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
+  }
+}
+
+{% endhighlight %}
+
+ +
+ +{% highlight python %} +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector + +# Load and parse the data +def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) + +data = sc.textFile("data/mllib/sample_linear_regression_data.txt") +parsedData = data.map(parsePoint) + +# Build the model +model = LinearRegressionWithSGD.train(parsedData) + +# Get predictions +valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + +# Instantiate metrics object +metrics = RegressionMetrics(valuesAndPreds) + +# Squared Error +print "MSE = %s" % metrics.meanSquaredError +print "RMSE = %s" % metrics.rootMeanSquaredError + +# R-squared +print "R-squared = %s" % metrics.r2 + +# Mean absolute error +print "MAE = %s" % metrics.meanAbsoluteError + +# Explained variance +print "Explained variance = %s" % metrics.explainedVariance + +{% endhighlight %} + +
+
\ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index d2d1cc93fe006..eea864eacf7c4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -48,6 +48,7 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [Feature extraction and transformation](mllib-feature-extraction.html) * [Frequent pattern mining](mllib-frequent-pattern-mining.html) * FP-growth +* [Evaluation Metrics](mllib-evaluation-metrics.html) * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 7c83d68e7993e..ccf922d9371fb 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -242,7 +242,7 @@ def parse_args(): help="Number of EBS volumes to attach to each node as /vol[x]. " + "The volumes will be deleted when the instances terminate. " + "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0." + + "EBS volumes are only attached if --ebs-vol-size > 0. " + "Only support up to 8 EBS volumes.") parser.add_option( "--placement-group", type="string", default=None, diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 0ff1b7ed0fd90..ca39358b75cb6 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -53,6 +53,8 @@ private class KinesisTestUtils( @volatile private var streamCreated = false + + @volatile private var _streamName: String = _ private lazy val kinesisClient = { @@ -115,21 +117,9 @@ private class KinesisTestUtils( shardIdToSeqNumbers.toMap } - def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = { - try { - val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) - val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() - Some(desc) - } catch { - case rnfe: ResourceNotFoundException => - None - } - } - def deleteStream(): Unit = { try { - if (describeStream().nonEmpty) { - val deleteStreamRequest = new DeleteStreamRequest() + if (streamCreated) { kinesisClient.deleteStream(streamName) } } catch { @@ -149,6 +139,17 @@ private class KinesisTestUtils( } } + private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + private def findNonExistentStreamName(): String = { var testStreamName: String = null do { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index b2e2a4246dbd5..e81fb11e5959f 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.streaming.kinesis -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import 
org.scalatest.BeforeAndAfterAll import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} -import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkException} class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { @@ -65,6 +65,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll } override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.deleteStream() + } if (sc != null) { sc.stop() } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 4992b041765e9..b88c9c6478d56 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -59,7 +59,7 @@ class KinesisStreamSuite extends KinesisFunSuite } } - ignore("KinesisUtils API") { + test("KinesisUtils API") { ssc = new StreamingContext(sc, Seconds(1)) // Tests the API, does not actually test data receiving val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", @@ -83,16 +83,16 @@ class KinesisStreamSuite extends KinesisFunSuite * you must have AWS credentials available through the default AWS provider chain, * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . */ - ignore("basic operation") { + testIfEnabled("basic operation") { val kinesisTestUtils = new KinesisTestUtils() try { kinesisTestUtils.createStream() ssc = new StreamingContext(sc, Seconds(1)) - val aWSCredentials = KinesisTestUtils.getAWSCredentials() + val awsCredentials = KinesisTestUtils.getAWSCredentials() val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName, kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_ONLY, - aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey) + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index cfcf7244eaed5..2ca60d51f8331 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -127,28 +127,25 @@ object Pregel extends Logging { var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { - // Receive the messages. Vertices that didn't get any messages do not appear in newVerts. - val newVerts = g.vertices.innerJoin(messages)(vprog).cache() - // Update the graph with the new vertices. + // Receive the messages and update the vertices. prevG = g - g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } - g.cache() + g = g.joinVertices(messages)(vprog).cache() val oldMessages = messages - // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't - // get to send messages. We must cache messages so it can be materialized on the next line, - // allowing us to uncache the previous iteration. 
- messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache() - // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This - // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the - // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). + // Send new messages, skipping edges where neither side received a message. We must cache + // messages so it can be materialized on the next line, allowing us to uncache the previous + // iteration. + messages = g.mapReduceTriplets( + sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages + // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages + // and the vertices of g). activeMessages = messages.count() logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs oldMessages.unpersist(blocking = false) - newVerts.unpersist(blocking = false) prevG.unpersistVertices(blocking = false) prevG.edges.unpersist(blocking = false) // count the iteration diff --git a/make-distribution.sh b/make-distribution.sh index cac7032bb2e87..4789b0e09cc8a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.6.4" +TACHYON_VERSION="0.7.0" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index fc0693f67cc2e..bc19bd6df894f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType */ @Experimental final class RandomForestClassifier(override val uid: String) - extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] + extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("rfc")) @@ -98,7 +98,7 @@ final class RandomForestClassifier(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - new RandomForestClassificationModel(trees) + new RandomForestClassificationModel(trees, numClasses) } override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) @@ -125,8 +125,9 @@ object RandomForestClassifier 
{ @Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, - private val _trees: Array[DecisionTreeClassificationModel]) - extends PredictionModel[Vector, RandomForestClassificationModel] + private val _trees: Array[DecisionTreeClassificationModel], + override val numClasses: Int) + extends ClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") @@ -135,8 +136,8 @@ final class RandomForestClassificationModel private[ml] ( * Construct a random forest classification model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeClassificationModel]) = - this(Identifiable.randomUID("rfc"), trees) + def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) = + this(Identifiable.randomUID("rfc"), trees, numClasses) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] @@ -153,20 +154,20 @@ final class RandomForestClassificationModel private[ml] ( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override protected def predictRaw(features: Vector): Vector = { // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. - val votes = mutable.Map.empty[Int, Double] + val votes = new Array[Double](numClasses) _trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt - votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight + votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } - votes.maxBy(_._2)._1 + Vectors.dense(votes) } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(uid, _trees), extra) + copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) } override def toString: String = { @@ -185,7 +186,8 @@ private[ml] object RandomForestClassificationModel { def fromOld( oldModel: OldRandomForestModel, parent: RandomForestClassifier, - categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = { + categoricalFeatures: Map[Int, Int], + numClasses: Int): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -193,6 +195,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees) + new RandomForestClassificationModel(uid, newTrees, numClasses) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 3825942795645..9c60d4084ec46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): 
StructType = { - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) @@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => if (nominal.values.isDefined) { - nominal.values.map(_.map(v => inputColName + is + v)) + nominal.values } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) } else { None } case binary: BinaryAttribute => if (binary.values.isDefined) { - binary.values.map(_.map(v => inputColName + is + v)) + binary.values } else { - Some(Array.tabulate(2)(i => inputColName + is + i)) + Some(Array.tabulate(2)(_.toString)) } case _: NumericAttribute => throw new RuntimeException( @@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer override def transform(dataset: DataFrame): DataFrame = { // schema transformation - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) val shouldDropLast = $(dropLast) @@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer math.max(m0, m1) } ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames val outputAttrs: Array[Attribute] = filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 0a95b1ee8de6e..d1726917e4517 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers @@ -78,17 +79,33 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** @group getParam */ def getFormula: String = $(formula) + /** Whether the formula specifies fitting an intercept. */ + private[ml] def hasIntercept: Boolean = { + require(parsedFormula.isDefined, "Must call setFormula() first.") + parsedFormula.get.hasIntercept + } + override def fit(dataset: DataFrame): RFormulaModel = { require(parsedFormula.isDefined, "Must call setFormula() first.") + val resolvedFormula = parsedFormula.get.resolve(dataset.schema) // StringType terms and terms representing interactions need to be encoded before assembly. 
// TODO(ekl) add support for feature interactions - var encoderStages = ArrayBuffer[PipelineStage]() - var tempColumns = ArrayBuffer[String]() - val encodedTerms = parsedFormula.get.terms.map { term => + val encoderStages = ArrayBuffer[PipelineStage]() + val tempColumns = ArrayBuffer[String]() + val takenNames = mutable.Set(dataset.columns: _*) + val encodedTerms = resolvedFormula.terms.map { term => dataset.schema(term) match { case column if column.dataType == StringType => val indexCol = term + "_idx_" + uid - val encodedCol = term + "_onehot_" + uid + val encodedCol = { + var tmp = term + while (takenNames.contains(tmp)) { + tmp += "_" + } + tmp + } + takenNames.add(indexCol) + takenNames.add(encodedCol) encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) tempColumns += indexCol @@ -103,7 +120,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R .setOutputCol($(featuresCol)) encoderStages += new ColumnPruner(tempColumns.toSet) val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) - copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this)) + copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) } // optimistic schema; does not contain any ML attributes @@ -124,13 +141,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** * :: Experimental :: * A fitted RFormula. Fitting is required to determine the factor levels of formula terms. - * @param parsedFormula a pre-parsed R formula. + * @param resolvedFormula the fitted R formula. * @param pipelineModel the fitted feature model, including factor to index mappings. */ @Experimental class RFormulaModel private[feature]( override val uid: String, - parsedFormula: ParsedRFormula, + resolvedFormula: ResolvedRFormula, pipelineModel: PipelineModel) extends Model[RFormulaModel] with RFormulaBase { @@ -144,8 +161,8 @@ class RFormulaModel private[feature]( val withFeatures = pipelineModel.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else if (schema.exists(_.name == parsedFormula.label)) { - val nullable = schema(parsedFormula.label).dataType match { + } else if (schema.exists(_.name == resolvedFormula.label)) { + val nullable = schema(resolvedFormula.label).dataType match { case _: NumericType | BooleanType => false case _ => true } @@ -158,12 +175,12 @@ class RFormulaModel private[feature]( } override def copy(extra: ParamMap): RFormulaModel = copyValues( - new RFormulaModel(uid, parsedFormula, pipelineModel)) + new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${parsedFormula})" + override def toString: String = s"RFormulaModel(${resolvedFormula})" private def transformLabel(dataset: DataFrame): DataFrame = { - val labelName = parsedFormula.label + val labelName = resolvedFormula.label if (hasLabelCol(dataset.schema)) { dataset } else if (dataset.schema.exists(_.name == labelName)) { @@ -207,26 +224,3 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) } - -/** - * Represents a parsed R formula. - */ -private[ml] case class ParsedRFormula(label: String, terms: Seq[String]) - -/** - * Limited implementation of R formula parsing. Currently supports: '~', '+'. 
- */ -private[ml] object RFormulaParser extends RegexParsers { - def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r - - def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } - - def formula: Parser[ParsedRFormula] = - (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) } - - def parse(value: String): ParsedRFormula = parseAll(formula, value) match { - case Success(result, _) => result - case failure: NoSuccess => throw new IllegalArgumentException( - "Could not parse formula: " + value) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala new file mode 100644 index 0000000000000..1ca3b92a7d92a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.types._ + +/** + * Represents a parsed R formula. + */ +private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { + /** + * Resolves formula terms into column names. A schema is necessary for inferring the meaning + * of the special '.' term. Duplicate terms will be removed during resolution. + */ + def resolve(schema: StructType): ResolvedRFormula = { + var includedTerms = Seq[String]() + terms.foreach { + case Dot => + includedTerms ++= simpleTypes(schema).filter(_ != label.value) + case ColumnRef(value) => + includedTerms :+= value + case Deletion(term: Term) => + term match { + case ColumnRef(value) => + includedTerms = includedTerms.filter(_ != value) + case Dot => + // e.g. "- .", which removes all first-order terms + val fromSchema = simpleTypes(schema) + includedTerms = includedTerms.filter(fromSchema.contains(_)) + case _: Deletion => + assert(false, "Deletion terms cannot be nested") + case _: Intercept => + } + case _: Intercept => + } + ResolvedRFormula(label.value, includedTerms.distinct) + } + + /** Whether this formula specifies fitting with an intercept term. */ + def hasIntercept: Boolean = { + var intercept = true + terms.foreach { + case Intercept(enabled) => + intercept = enabled + case Deletion(Intercept(enabled)) => + intercept = !enabled + case _ => + } + intercept + } + + // the dot operator excludes complex column types + private def simpleTypes(schema: StructType): Seq[String] = { + schema.fields.filter(_.dataType match { + case _: NumericType | StringType | BooleanType | _: VectorUDT => true + case _ => false + }).map(_.name) + } +} + +/** + * Represents a fully evaluated and simplified R formula. 
+ */ +private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) + +/** + * R formula terms. See the R formula docs here for more information: + * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +private[ml] sealed trait Term + +/* R formula reference to all available columns, e.g. "." in a formula */ +private[ml] case object Dot extends Term + +/* R formula reference to a column, e.g. "+ Species" in a formula */ +private[ml] case class ColumnRef(value: String) extends Term + +/* R formula intercept toggle, e.g. "+ 0" in a formula */ +private[ml] case class Intercept(enabled: Boolean) extends Term + +/* R formula deletion of a variable, e.g. "- Species" in a formula */ +private[ml] case class Deletion(term: Term) extends Term + +/** + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'. + */ +private[ml] object RFormulaParser extends RegexParsers { + def intercept: Parser[Intercept] = + "([01])".r ^^ { case a => Intercept(a == "1") } + + def columnRef: Parser[ColumnRef] = + "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } + + def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot } + + def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { + case op ~ list => list.foldLeft(List(op)) { + case (left, "+" ~ right) => left ++ Seq(right) + case (left, "-" ~ right) => left ++ Seq(Deletion(right)) + } + } + + def formula: Parser[ParsedRFormula] = + (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + + def parse(value: String): ParsedRFormula = parseAll(formula, value) match { + case Success(result, _) => result + case failure: NoSuccess => throw new IllegalArgumentException( + "Could not parse formula: " + value) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0b3af4747e693..248288ca73e99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -50,7 +50,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S /** * :: Experimental :: * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split - * the text (default) or repeatedly matching the regex (if `gaps` is true). + * the text (default) or repeatedly matching the regex (if `gaps` is false). * Optional parameters also allow filtering tokens using a minimal length. * It returns an array of strings that can be empty. 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 1ee080641e3e3..f5a022c31ed90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.api.r +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.feature.RFormula -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame @@ -32,10 +33,38 @@ private[r] object SparkRWrappers { alpha: Double): PipelineModel = { val formula = new RFormula().setFormula(value) val estimator = family match { - case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha) - case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha) + case "gaussian" => new LinearRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) + case "binomial" => new LogisticRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) } val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) } + + def getModelWeights(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => + Array(m.intercept) ++ m.weights.toArray + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No weights available for LogisticRegressionModel") // SPARK-9492 + } + } + + def getModelFeatures(model: PipelineModel): Array[String] = { + model.stages.last match { + case m: LinearRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No features names available for LogisticRegressionModel") // SPARK-9492 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala new file mode 100644 index 0000000000000..4ece8cf8cf0b6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DoubleType, DataType} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.storage.StorageLevel + +/** + * Params for isotonic regression. + */ +private[regression] trait IsotonicRegressionParams extends PredictorParams { + + /** + * Param for weight column name. + * TODO: Move weightCol to sharedParams. + * + * @group param + */ + final val weightCol: Param[String] = + new Param[String](this, "weightCol", "weight column name") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) + + /** + * Param for isotonic parameter. + * Isotonic (increasing) or antitonic (decreasing) sequence. + * @group param + */ + final val isotonic: BooleanParam = + new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence") + + /** @group getParam */ + final def getIsotonicParam: Boolean = $(isotonic) +} + +/** + * :: Experimental :: + * Isotonic regression. + * + * Currently implemented using parallelized pool adjacent violators algorithm. + * Only univariate (single feature) algorithm supported. + * + * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. + */ +@Experimental +class IsotonicRegression(override val uid: String) + extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel] + with IsotonicRegressionParams { + + def this() = this(Identifiable.randomUID("isoReg")) + + /** + * Set the isotonic parameter. + * Default is true. + * @group setParam + */ + def setIsotonicParam(value: Boolean): this.type = set(isotonic, value) + setDefault(isotonic -> true) + + /** + * Set weight column param. + * Default is weight. + * @group setParam + */ + def setWeightParam(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "weight") + + override private[ml] def featuresDataType: DataType = DoubleType + + override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + + private[this] def extractWeightedLabeledPoints( + dataset: DataFrame): RDD[(Double, Double, Double)] = { + + dataset.select($(labelCol), $(featuresCol), $(weightCol)) + .map { case Row(label: Double, features: Double, weights: Double) => + (label, features, weights) + } + } + + override protected def train(dataset: DataFrame): IsotonicRegressionModel = { + SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType) + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val instances = extractWeightedLabeledPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) + val parentModel = isotonicRegression.run(instances) + + new IsotonicRegressionModel(uid, parentModel) + } +} + +/** + * :: Experimental :: + * Model fitted by IsotonicRegression. + * Predicts using a piecewise linear function. 
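+ *
+ * A minimal usage sketch (the DataFrame names below are assumptions for illustration; "label",
+ * "features", and "weight" are the default column names and are expected to be Double columns):
+ * {{{
+ *   val model = new IsotonicRegression().setIsotonicParam(true).fit(training)
+ *   val predictions = model.transform(test)
+ * }}}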
+ * + * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]]. + * + * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] + * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. + */ +class IsotonicRegressionModel private[ml] ( + override val uid: String, + private[ml] val parentModel: MLlibIsotonicRegressionModel) + extends RegressionModel[Double, IsotonicRegressionModel] + with IsotonicRegressionParams { + + override def featuresDataType: DataType = DoubleType + + override protected def predict(features: Double): Double = { + parentModel.predict(features) + } + + override def copy(extra: ParamMap): IsotonicRegressionModel = { + copyValues(new IsotonicRegressionModel(uid, parentModel), extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 89718e0f3e15a..3b85ba001b128 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.StructField import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter @@ -146,9 +147,10 @@ class LinearRegression(override val uid: String) val model = new LinearRegressionModel(uid, weights, intercept) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } @@ -221,9 +223,10 @@ class LinearRegression(override val uid: String) val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + val featuresCol: String, val objectiveHistory: Array[Double]) extends LinearRegressionSummary(predictions, predictionCol, labelCol) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala new file mode 100644 index 0000000000000..0ec88ef77d695 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import java.util.{List => JList} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix} +import org.apache.spark.mllib.clustering.GaussianMixtureModel + +/** + * Wrapper around GaussianMixtureModel to provide helper methods in Python + */ +private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { + val weights: Vector = Vectors.dense(model.weights) + val k: Int = weights.size + + /** + * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian + */ + val gaussians: JList[Object] = { + val modelGaussians = model.gaussians + var i = 0 + var mu = ArrayBuffer.empty[Vector] + var sigma = ArrayBuffer.empty[Matrix] + while (i < k) { + mu += modelGaussians(i).mu + sigma += modelGaussians(i).sigma + i += 1 + } + List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fda8d5a0b048f..6f080d32bbf4d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable { seed: java.lang.Long, initialModelWeights: java.util.ArrayList[Double], initialModelMu: java.util.ArrayList[Vector], - initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = { + initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = { val gmmAlg = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) @@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) gmmAlg.setSeed(seed) try { - val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) - var wt = ArrayBuffer.empty[Double] - var mu = ArrayBuffer.empty[Vector] - var sigma = ArrayBuffer.empty[Matrix] - for (i <- 0 until model.k) { - wt += model.weights(i) - mu += model.gaussians(i).mu - sigma += model.gaussians(i).sigma - } - List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))) } finally { data.rdd.unpersist(blocking = false) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 31c1d520fd659..6cfad3fbbdb87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV} - +import breeze.linalg.{DenseMatrix => 
BDM, DenseVector => BDV, normalize, sum} +import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -28,14 +27,13 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector} -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.BoundedPriorityQueue - /** * :: Experimental :: * @@ -53,6 +51,31 @@ abstract class LDAModel private[clustering] extends Saveable { /** Vocabulary size (number of terms or terms in the vocabulary) */ def vocabSize: Int + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a Dirichlet distribution. + */ + def docConcentration: Vector + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + */ + def topicConcentration: Double + + /** + * Shape parameter for random initialization of variational parameter gamma. + * Used for variational inference for perplexity and other test-time computations. + */ + protected def gammaShape: Double + /** * Inferred topics, where each topic is represented by a distribution over terms. * This is a matrix of size vocabSize x k, where each column is a topic. @@ -163,12 +186,14 @@ abstract class LDAModel private[clustering] extends Saveable { * This model stores only the inferred topics. * It may be used for computing topics for new documents, but it may give less accurate answers * than the [[DistributedLDAModel]]. - * * @param topics Inferred topics (vocabSize x k matrix). */ @Experimental class LocalLDAModel private[clustering] ( - private val topics: Matrix) extends LDAModel with Serializable { + val topics: Matrix, + override val docConcentration: Vector, + override val topicConcentration: Double, + override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable { override def k: Int = topics.numCols @@ -189,16 +214,122 @@ class LocalLDAModel private[clustering] ( override protected def formatVersion = "1.0" override def save(sc: SparkContext, path: String): Unit = { - LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix) + LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, + gammaShape) } // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? - // TODO: - // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + /** + * Calculate the log variational bound on perplexity. See Equation (16) in original Online + * LDA paper. 
+ * @param documents test corpus to use for calculating perplexity + * @return the log perplexity per word + */ + def logPerplexity(documents: RDD[(Long, Vector)]): Double = { + val corpusWords = documents + .map { case (_, termCounts) => termCounts.toArray.sum } + .sum() + val batchVariationalBound = bound(documents, docConcentration, + topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) + val perWordBound = batchVariationalBound / corpusWords + + perWordBound + } + + /** + * Estimate the variational likelihood bound of from `documents`: + * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)] + * This bound is derived by decomposing the LDA model to: + * log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p) + * and noting that the KL-divergence D(q|p) >= 0. See Equation (16) in original Online LDA paper. + * @param documents a subset of the test corpus + * @param alpha document-topic Dirichlet prior parameters + * @param eta topic-word Dirichlet prior parameters + * @param lambda parameters for variational q(beta | lambda) topic-word distributions + * @param gammaShape shape parameter for random initialization of variational q(theta | gamma) + * topic mixture distributions + * @param k number of topics + * @param vocabSize number of unique terms in the entire test corpus + */ + private def bound( + documents: RDD[(Long, Vector)], + alpha: Vector, + eta: Double, + lambda: BDM[Double], + gammaShape: Double, + k: Int, + vocabSize: Long): Double = { + val brzAlpha = alpha.toBreeze.toDenseVector + // transpose because dirichletExpectation normalizes by row and we need to normalize + // by topic (columns of lambda) + val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t + + var score = documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) => + var docScore = 0.0D + val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, exp(Elogbeta), brzAlpha, gammaShape, k) + val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) + + // E[log p(doc | theta, beta)] + termCounts.foreachActive { case (idx, count) => + docScore += count * LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t) + } + // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector + docScore += sum((brzAlpha - gammad) :* Elogthetad) + docScore += sum(lgamma(gammad) - lgamma(brzAlpha)) + docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) + + docScore + }.sum() + + // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar + score += sum((eta - lambda) :* Elogbeta) + score += sum(lgamma(lambda) - lgamma(eta)) + + val sumEta = eta * vocabSize + score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) + + score + } + + /** + * Predicts the topic mixture distribution for each document (often called "theta" in the + * literature). Returns a vector of zeros for an empty document. + * + * This uses a variational approximation following Hoffman et al. (2010), where the approximate + * distribution is called "gamma." Technically, this method returns this approximation "gamma" + * for each document. 
+ * @param documents documents to predict topic mixture distributions for + * @return An RDD of (document ID, topic mixture distribution for document) + */ + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { + // Double transpose because dirichletExpectation normalizes by row and we need to normalize + // by topic (columns of lambda) + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val docConcentrationBrz = this.docConcentration.toBreeze + val gammaShape = this.gammaShape + val k = this.k + + documents.map { case (id: Long, termCounts: Vector) => + if (termCounts.numNonzeros == 0) { + (id, Vectors.zeros(k)) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, + expElogbeta, + docConcentrationBrz, + gammaShape, + k) + (id, Vectors.dense(normalize(gamma, 1.0).toArray)) + } + } + } } + @Experimental object LocalLDAModel extends Loader[LocalLDAModel] { @@ -212,14 +343,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] { // as a Row in data. case class Data(topic: Vector, index: Int) - def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = { + def save( + sc: SparkContext, + path: String, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): Unit = { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val k = topicsMatrix.numCols val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ - ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows))) + ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~ + ("docConcentration" -> docConcentration.toArray.toSeq) ~ + ("topicConcentration" -> topicConcentration) ~ + ("gammaShape" -> gammaShape))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix @@ -229,7 +369,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] { sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path)) } - def load(sc: SparkContext, path: String): LocalLDAModel = { + def load( + sc: SparkContext, + path: String, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): LocalLDAModel = { val dataPath = Loader.dataPath(path) val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) @@ -243,7 +388,10 @@ object LocalLDAModel extends Loader[LocalLDAModel] { topics.foreach { case Row(vec: Vector, ind: Int) => brzTopics(::, ind) := vec.toBreeze } - new LocalLDAModel(Matrices.fromBreeze(brzTopics)) + val topicsMat = Matrices.fromBreeze(brzTopics) + + // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940 + new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape) } } @@ -252,15 +400,19 @@ object LocalLDAModel extends Loader[LocalLDAModel] { implicit val formats = DefaultFormats val expectedK = (metadata \ "k").extract[Int] val expectedVocabSize = (metadata \ "vocabSize").extract[Int] + val docConcentration = + Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray) + val topicConcentration = (metadata \ "topicConcentration").extract[Double] + val gammaShape = (metadata \ "gammaShape").extract[Double] val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { case (className, "1.0") if className 
== classNameV1_0 => - SaveLoadV1_0.load(sc, path) + SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape) case _ => throw new Exception( s"LocalLDAModel.load did not recognize model with (className, format version):" + - s"($loadedClassName, $loadedVersion). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s"($loadedClassName, $loadedVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") } val topicsMatrix = model.topicsMatrix @@ -268,7 +420,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics") require(expectedVocabSize == topicsMatrix.numRows, s"LocalLDAModel requires $expectedVocabSize terms for each topic, " + - s"but got ${topicsMatrix.numRows}") + s"but got ${topicsMatrix.numRows}") model } } @@ -282,28 +434,25 @@ object LocalLDAModel extends Loader[LocalLDAModel] { * than the [[LocalLDAModel]]. */ @Experimental -class DistributedLDAModel private ( +class DistributedLDAModel private[clustering] ( private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], private[clustering] val globalTopicTotals: LDA.TopicCounts, val k: Int, val vocabSize: Int, - private[clustering] val docConcentration: Double, - private[clustering] val topicConcentration: Double, + override val docConcentration: Vector, + override val topicConcentration: Double, + override protected[clustering] val gammaShape: Double, private[spark] val iterationTimes: Array[Double]) extends LDAModel { import LDA._ - private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = { - this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, - state.topicConcentration, iterationTimes) - } - /** * Convert model to a local model. * The local model stores the inferred topics but not the topic distributions for training * documents. */ - def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix) + def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) /** * Inferred topics, where each topic is represented by a distribution over terms. @@ -375,8 +524,9 @@ class DistributedLDAModel private ( * hyperparameters. */ lazy val logLikelihood: Double = { - val eta = topicConcentration - val alpha = docConcentration + // TODO: generalize this for asymmetric (non-scalar) alpha + val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object + val eta = this.topicConcentration assert(eta > 1.0) assert(alpha > 1.0) val N_k = globalTopicTotals @@ -400,8 +550,9 @@ class DistributedLDAModel private ( * log P(topics, topic distributions for docs | alpha, eta) */ lazy val logPrior: Double = { - val eta = topicConcentration - val alpha = docConcentration + // TODO: generalize this for asymmetric (non-scalar) alpha + val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object + val eta = this.topicConcentration // Term vertices: Compute phi_{wk}. Use to compute prior log probability. // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. 
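    // The aggregation below smooths the topic counts by (eta - 1.0) at term vertices and
    // (alpha - 1.0) at doc vertices, forms phi_{wk} and theta_{kj}, and sums the weighted
    // log contributions (eta - 1.0) * sum(log phi_{wk}) and (alpha - 1.0) * sum(log theta_{kj}).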
val N_k = globalTopicTotals @@ -412,12 +563,12 @@ class DistributedLDAModel private ( val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k - (eta - 1.0) * brzSum(phi_wk.map(math.log)) + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) - (alpha - 1.0) * brzSum(theta_kj.map(math.log)) + (alpha - 1.0) * sum(theta_kj.map(math.log)) } } graph.vertices.aggregate(0.0)(seqOp, _ + _) @@ -448,7 +599,7 @@ class DistributedLDAModel private ( override def save(sc: SparkContext, path: String): Unit = { DistributedLDAModel.SaveLoadV1_0.save( sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, - iterationTimes) + iterationTimes, gammaShape) } } @@ -460,7 +611,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val thisFormatVersion = "1.0" - val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel" + val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel" // Store globalTopicTotals as a Vector. case class Data(globalTopicTotals: Vector) @@ -478,17 +629,20 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { globalTopicTotals: LDA.TopicCounts, k: Int, vocabSize: Int, - docConcentration: Double, + docConcentration: Vector, topicConcentration: Double, - iterationTimes: Array[Double]): Unit = { + iterationTimes: Array[Double], + gammaShape: Double): Unit = { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render - (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~ - ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~ - ("topicConcentration" -> topicConcentration) ~ - ("iterationTimes" -> iterationTimes.toSeq))) + (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ + ("docConcentration" -> docConcentration.toArray.toSeq) ~ + ("topicConcentration" -> topicConcentration) ~ + ("iterationTimes" -> iterationTimes.toSeq) ~ + ("gammaShape" -> gammaShape))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString @@ -510,9 +664,10 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { sc: SparkContext, path: String, vocabSize: Int, - docConcentration: Double, + docConcentration: Vector, topicConcentration: Double, - iterationTimes: Array[Double]): DistributedLDAModel = { + iterationTimes: Array[Double], + gammaShape: Double): DistributedLDAModel = { val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString @@ -536,7 +691,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, - docConcentration, topicConcentration, iterationTimes) + docConcentration, topicConcentration, gammaShape, iterationTimes) } } @@ -546,32 +701,35 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { implicit val formats = DefaultFormats val expectedK = (metadata \ "k").extract[Int] val vocabSize = 
(metadata \ "vocabSize").extract[Int] - val docConcentration = (metadata \ "docConcentration").extract[Double] + val docConcentration = + Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray) val topicConcentration = (metadata \ "topicConcentration").extract[Double] val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]] - val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + val gammaShape = (metadata \ "gammaShape").extract[Double] + val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { case (className, "1.0") if className == classNameV1_0 => { - DistributedLDAModel.SaveLoadV1_0.load( - sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray) + DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration, + topicConcentration, iterationTimes.toArray, gammaShape) } case _ => throw new Exception( s"DistributedLDAModel.load did not recognize model with (className, format version):" + - s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") + s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") } require(model.vocabSize == vocabSize, s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize") require(model.docConcentration == docConcentration, s"DistributedLDAModel requires $docConcentration docConcentration, " + - s"got ${model.docConcentration} docConcentration") + s"got ${model.docConcentration} docConcentration") require(model.topicConcentration == topicConcentration, s"DistributedLDAModel requires $topicConcentration docConcentration, " + - s"got ${model.topicConcentration} docConcentration") + s"got ${model.topicConcentration} docConcentration") require(expectedK == model.k, s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics") model } } + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index f4170a3d98dd8..d6f8b29a43dfd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import java.util.Random import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} -import breeze.numerics.{abs, digamma, exp} +import breeze.numerics.{abs, exp} import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.DeveloperApi @@ -142,8 +142,9 @@ final class EMLDAOptimizer extends LDAOptimizer { this.k = k this.vocabSize = docs.take(1).head._2.size this.checkpointInterval = lda.getCheckpointInterval - this.graphCheckpointer = new - PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) + this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + checkpointInterval, graph.vertices.sparkContext) + this.graphCheckpointer.update(this.graph) this.globalTopicTotals = computeGlobalTopicTotals() this } @@ -188,7 +189,7 @@ final class EMLDAOptimizer extends LDAOptimizer { // Update the vertex descriptors with the new counts. 
val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) graph = newGraph - graphCheckpointer.updateGraph(newGraph) + graphCheckpointer.update(newGraph) globalTopicTotals = computeGlobalTopicTotals() this } @@ -208,7 +209,11 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") this.graphCheckpointer.deleteAllCheckpoints() - new DistributedLDAModel(this, iterationTimes) + // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal + // conversion + new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, + Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, + 100, iterationTimes) } } @@ -385,71 +390,52 @@ final class OnlineLDAOptimizer extends LDAOptimizer { iteration += 1 val k = this.k val vocabSize = this.vocabSize - val Elogbeta = dirichletExpectation(lambda).t - val expElogbeta = exp(Elogbeta) + val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t val alpha = this.alpha.toBreeze val gammaShape = this.gammaShape - val stats: RDD[BDM[Double]] = batch.mapPartitions { docs => - val stat = BDM.zeros[Double](k, vocabSize) - docs.foreach { doc => - val termCounts = doc._2 - val (ids: List[Int], cts: Array[Double]) = termCounts match { - case v: DenseVector => ((0 until v.size).toList, v.values) - case v: SparseVector => (v.indices.toList, v.values) - case v => throw new IllegalArgumentException("Online LDA does not support vector type " - + v.getClass) - } - if (!ids.isEmpty) { - - // Initialize the variational distribution q(theta|gamma) for the mini-batch - val gammad: BDV[Double] = - new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K - val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K - val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K - - val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids - var meanchange = 1D - val ctsVector = new BDV[Double](cts) // ids - - // Iterate between gamma and phi until convergence - while (meanchange > 1e-3) { - val lastgamma = gammad.copy - // K K * ids ids - gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha - expElogthetad := exp(digamma(gammad) - digamma(sum(gammad))) - phinorm := expElogbetad * expElogthetad :+ 1e-100 - meanchange = sum(abs(gammad - lastgamma)) / k - } + val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) - stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix + val stat = BDM.zeros[Double](k, vocabSize) + var gammaPart = List[BDV[Double]]() + nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) => + val ids: List[Int] = termCounts match { + case v: DenseVector => (0 until v.size).toList + case v: SparseVector => v.indices.toList } + val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbeta, alpha, gammaShape, k) + stat(::, ids) := stat(::, ids).toDenseMatrix + sstats + gammaPart = gammad :: gammaPart } - Iterator(stat) + Iterator((stat, gammaPart)) } - - val statsSum: BDM[Double] = stats.reduce(_ += _) + val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _) + val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( + stats.map(_._2).reduce(_ ++ 
_).map(_.toDenseMatrix): _*) val batchResult = statsSum :* expElogbeta.t // Note that this is an optimization to avoid batch.count - update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt) + updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) this } - override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { - new LocalLDAModel(Matrices.fromBreeze(lambda).transpose) - } - /** * Update lambda based on the batch submitted. batchSize can be different for each iteration. */ - private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = { + private def updateLambda(stat: BDM[Double], batchSize: Int): Unit = { // weight of the mini-batch. - val weight = math.pow(getTau0 + iter, -getKappa) + val weight = rho() // Update lambda based on documents. - lambda = lambda * (1 - weight) + - (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight + lambda := (1 - weight) * lambda + + weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) + } + + /** Calculates learning rate rho, which decays as a function of [[iteration]] */ + private def rho(): Double = { + math.pow(getTau0 + this.iteration, -getKappa) } /** @@ -463,15 +449,57 @@ final class OnlineLDAOptimizer extends LDAOptimizer { new BDM[Double](col, row, temp).t } + override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + } + +} + +/** + * Serializable companion object containing helper methods and shared code for + * [[OnlineLDAOptimizer]] and [[LocalLDAModel]]. + */ +private[clustering] object OnlineLDAOptimizer { /** - * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation - * uses digamma which is accurate but expensive. + * Uses variational inference to infer the topic distribution `gammad` given the term counts + * for a document. `termCounts` must contain at least one non-zero entry, otherwise Breeze will + * throw a BLAS error. + * + * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001) + * avoids explicit computation of variational parameter `phi`. 
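+ *
+ * Sketch of the iteration implemented below: gamma is initialized by sampling from
+ * Gamma(gammaShape, 1 / gammaShape) and repeatedly updated as
+ * alpha + expElogtheta :* (expElogbeta.t * (counts :/ phinorm)), with
+ * phinorm = expElogbeta * expElogtheta + 1e-100, until the mean absolute change in gamma
+ * falls below 1e-3.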
+ * @see [[http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.7566]] */ - private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { - val rowSum = sum(alpha(breeze.linalg.*, ::)) - val digAlpha = digamma(alpha) - val digRowSum = digamma(rowSum) - val result = digAlpha(::, breeze.linalg.*) - digRowSum - result + private[clustering] def variationalTopicInference( + termCounts: Vector, + expElogbeta: BDM[Double], + alpha: breeze.linalg.Vector[Double], + gammaShape: Double, + k: Int): (BDV[Double], BDM[Double]) = { + val (ids: List[Int], cts: Array[Double]) = termCounts match { + case v: DenseVector => ((0 until v.size).toList, v.values) + case v: SparseVector => (v.indices.toList, v.values) + } + // Initialize the variational distribution q(theta|gamma) for the mini-batch + val gammad: BDV[Double] = + new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K + val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K + + val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids + var meanchange = 1D + val ctsVector = new BDV[Double](cts) // ids + + // Iterate between gamma and phi until convergence + while (meanchange > 1e-3) { + val lastgamma = gammad.copy + // K K * ids ids + gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha + expElogthetad := exp(LDAUtils.dirichletExpectation(gammad)) + phinorm := expElogbetad * expElogthetad :+ 1e-100 + meanchange = sum(abs(gammad - lastgamma)) / k + } + + val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix + (gammad, sstatsd) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala new file mode 100644 index 0000000000000..f7e5ce1665fe6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum} +import breeze.numerics._ + +/** + * Utility methods for LDA. + */ +object LDAUtils { + /** + * Log Sum Exp with overflow protection using the identity: + * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\} + */ + private[clustering] def logSumExp(x: BDV[Double]): Double = { + val a = max(x) + a + log(sum(exp(x :- a))) + } + + /** + * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation + * uses [[breeze.numerics.digamma]] which is accurate but expensive. 
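+ *
+ * Concretely, E[log(theta_k)] = digamma(alpha_k) - digamma(sum_j alpha_j), which is what the
+ * vector form below computes directly.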
+ */ + private[clustering] def dirichletExpectation(alpha: BDV[Double]): BDV[Double] = { + digamma(alpha) - digamma(sum(alpha)) + } + + /** + * Computes [[dirichletExpectation()]] row-wise, assuming each row of alpha are + * Dirichlet parameters. + */ + private[clustering] def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { + val rowSum = sum(alpha(breeze.linalg.*, ::)) + val digAlpha = digamma(alpha) + val digRowSum = digamma(rowSum) + val result = digAlpha(::, breeze.linalg.*) - digRowSum + result + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 7ead6327486cc..0ea792081086d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { minCount: Long, maxPatternLength: Int, prefixes: List[Int], - database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { + database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = { if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) @@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } - def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { + def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = { database .map(getSuffix(prefix, _)) .filter(_.nonEmpty) @@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { */ private def getFreqItemAndCounts( minCount: Long, - database: Array[Array[Int]]): mutable.Map[Int, Long] = { + database: Iterable[Array[Int]]): mutable.Map[Int, Long] = { // TODO: use PrimitiveKeyOpenHashMap val counts = mutable.Map[Int, Long]().withDefaultValue(0L) database.foreach { sequence => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 6f52db7b073ae..e6752332cdeeb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD @@ -43,28 +45,45 @@ class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int) extends Logging with Serializable { + /** + * The maximum number of items allowed in a projected database before local processing. If a + * projected database exceeds this size, another iteration of distributed PrefixSpan is run. + */ + // TODO: make configurable with a better default value, 10000 may be too small + private val maxLocalProjDBSize: Long = 10000 + /** * Constructs a default instance with default parameters * {minSupport: `0.1`, maxPatternLength: `10`}. */ def this() = this(0.1, 10) + /** + * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered + * frequent). + */ + def getMinSupport: Double = this.minSupport + /** * Sets the minimal support level (default: `0.1`). 
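 * For example (an illustrative calculation): with 1,000 input sequences and minSupport = 0.1,
 * a pattern must appear in at least ceil(1000 * 0.1) = 100 sequences to be reported as frequent.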
*/ def setMinSupport(minSupport: Double): this.type = { - require(minSupport >= 0 && minSupport <= 1, - "The minimum support value must be between 0 and 1, including 0 and 1.") + require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].") this.minSupport = minSupport this } + /** + * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider). + */ + def getMaxPatternLength: Int = this.maxPatternLength + /** * Sets maximal pattern length (default: `10`). */ def setMaxPatternLength(maxPatternLength: Int): this.type = { - require(maxPatternLength >= 1, - "The maximum pattern length value must be greater than 0.") + // TODO: support unbounded pattern length when maxPatternLength = 0 + require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.") this.maxPatternLength = maxPatternLength this } @@ -78,81 +97,153 @@ class PrefixSpan private ( * the value of pair is the pattern's count. */ def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { + val sc = sequences.sparkContext + if (sequences.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } - val minCount = getMinCount(sequences) - val lengthOnePatternsAndCounts = - getFreqItemAndCounts(minCount, sequences).collect() - val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase( - lengthOnePatternsAndCounts.map(_._1), sequences) - val groupedProjectedDatabase = prefixAndProjectedDatabase - .map(x => (x._1.toSeq, x._2)) - .groupByKey() - .map(x => (x._1.toArray, x._2.toArray)) - val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase) - val lengthOnePatternsAndCountsRdd = - sequences.sparkContext.parallelize( - lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) - val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns - allPatterns + + // Convert min support to a min number of transactions for this dataset + val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong + + // Frequent items -> number of occurrences; all items here satisfy the `minSupport` threshold + val freqItemCounts = sequences + .flatMap(seq => seq.distinct.map(item => (item, 1L))) + .reduceByKey(_ + _) + .filter(_._2 >= minCount) + .collect() + + // Pairs of (length 1 prefix, suffix consisting of frequent items) + val itemSuffixPairs = { + val freqItems = freqItemCounts.map(_._1).toSet + sequences.flatMap { seq => + val filteredSeq = seq.filter(freqItems.contains(_)) + freqItems.flatMap { item => + val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq) + candidateSuffix match { + case suffix if !suffix.isEmpty => Some((List(item), suffix)) + case _ => None + } + } + } + } + + // Accumulator for the computed results to be returned, initialized to the frequent items (i.e. + // frequent length-one prefixes) + var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2)) + + // Remaining work to be processed locally and in a distributed manner, respectively + var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs) + + // Continue processing until no pairs for distributed processing remain (i.e.
all prefixes have + // projected database sizes <= `maxLocalProjDBSize`) + while (pairsForDistributed.count() != 0) { + val (nextPatternAndCounts, nextPrefixSuffixPairs) = + extendPrefixes(minCount, pairsForDistributed) + pairsForDistributed.unpersist() + val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs) + pairsForDistributed = largerPairsPart + pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK) + pairsForLocal ++= smallerPairsPart + resultsAccumulator ++= nextPatternAndCounts.collect() + } + + // Process the small projected databases locally + val remainingResults = getPatternsInLocal( + minCount, sc.parallelize(pairsForLocal, 1).groupByKey()) + + (sc.parallelize(resultsAccumulator, 1) ++ remainingResults) + .map { case (pattern, count) => (pattern.toArray, count) } } + /** - * Get the minimum count (sequences count * minSupport). - * @param sequences input data set, contains a set of sequences, - * @return minimum count, + * Partitions the prefix-suffix pairs by projected database size. + * @param prefixSuffixPairs prefix (length n) and suffix pairs, + * @return prefix-suffix pairs partitioned by whether their projected database size is <= or + * greater than [[maxLocalProjDBSize]] */ - private def getMinCount(sequences: RDD[Array[Int]]): Long = { - if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong + private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])]) + : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = { + val prefixToSuffixSize = prefixSuffixPairs + .aggregateByKey(0)( + seqOp = { case (count, suffix) => count + suffix.length }, + combOp = { _ + _ }) + val smallPrefixes = prefixToSuffixSize + .filter(_._2 <= maxLocalProjDBSize) + .keys + .collect() + .toSet + val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) } + val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) } + (small.collect(), large) } /** - * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the absolute minimum count - * @param sequences original sequences data - * @return array of item and count pair + * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes + * and remaining work. + * @param minCount minimum count + * @param prefixSuffixPairs prefix (length N) and suffix pairs, + * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended + * prefix, corresponding suffix) pairs. */ - private def getFreqItemAndCounts( + private def extendPrefixes( minCount: Long, - sequences: RDD[Array[Int]]): RDD[(Int, Long)] = { - sequences.flatMap(_.distinct.map((_, 1L))) + prefixSuffixPairs: RDD[(List[Int], Array[Int])]) + : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = { + + // (length N prefix, item from suffix) pairs and their corresponding number of occurrences + // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport` + val prefixItemPairAndCounts = prefixSuffixPairs + .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) } .reduceByKey(_ + _) .filter(_._2 >= minCount) - } - /** - * Get the frequent prefixes' projected database. 
- * @param frequentPrefixes frequent prefixes - * @param sequences sequences data - * @return prefixes and projected database - */ - private def getPrefixAndProjectedDatabase( - frequentPrefixes: Array[Int], - sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = { - val filteredSequences = sequences.map { p => - p.filter (frequentPrefixes.contains(_) ) - } - filteredSequences.flatMap { x => - frequentPrefixes.map { y => - val sub = LocalPrefixSpan.getSuffix(y, x) - (Array(y), sub) - }.filter(_._2.nonEmpty) - } + // Map from prefix to set of possible next items from suffix + val prefixToNextItems = prefixItemPairAndCounts + .keys + .groupByKey() + .mapValues(_.toSet) + .collect() + .toMap + + + // Frequent patterns with length N+1 and their corresponding counts + val extendedPrefixAndCounts = prefixItemPairAndCounts + .map { case ((prefix, item), count) => (item :: prefix, count) } + + // Remaining work, all prefixes will have length N+1 + val extendedPrefixAndSuffix = prefixSuffixPairs + .filter(x => prefixToNextItems.contains(x._1)) + .flatMap { case (prefix, suffix) => + val frequentNextItems = prefixToNextItems(prefix) + val filteredSuffix = suffix.filter(frequentNextItems.contains(_)) + frequentNextItems.flatMap { item => + LocalPrefixSpan.getSuffix(item, filteredSuffix) match { + case suffix if !suffix.isEmpty => Some(item :: prefix, suffix) + case _ => None + } + } + } + + (extendedPrefixAndCounts, extendedPrefixAndSuffix) } /** - * calculate the patterns in local. + * Calculate the patterns locally. * @param minCount the absolute minimum count - * @param data patterns and projected sequences data data + * @param data prefixes and projected sequences data * @return patterns */ private def getPatternsInLocal( minCount: Long, - data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { - data.flatMap { case (prefix, projDB) => - LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) - .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) } + data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = { + data.flatMap { + case (prefix, projDB) => + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB) + .map { case (pattern: List[Int], count: Long) => + (pattern.reverse, count) + } } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala new file mode 100644 index 0000000000000..72d3aabc9b1f4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.mllib.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.storage.StorageLevel + + +/** + * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs + * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to + * the distributed data type (RDD, Graph, etc.). + * + * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing, + * as well as unpersisting and removing checkpoint files. + * + * Users should call update() when a new Dataset has been created, + * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are + * responsible for materializing the Dataset to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets. + * - Unpersist Datasets from queue until there are at most 3 persisted Datasets. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Datasets should be + * checkpointed). + * - This class removes checkpoint files once later Datasets have been checkpointed. + * However, references to the older Datasets will still return isCheckpointed = true. + * + * @param checkpointInterval Datasets will be checkpointed at this interval + * @param sc SparkContext for the Datasets given to this checkpointer + * @tparam T Dataset type, such as RDD[Double] + */ +private[mllib] abstract class PeriodicCheckpointer[T]( + val checkpointInterval: Int, + val sc: SparkContext) extends Logging { + + /** FIFO queue of past checkpointed Datasets */ + private val checkpointQueue = mutable.Queue[T]() + + /** FIFO queue of past persisted Datasets */ + private val persistedQueue = mutable.Queue[T]() + + /** Number of times [[update()]] has been called */ + private var updateCount = 0 + + /** + * Update with a new Dataset. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the Dataset + * has been materialized. + * + * @param newData New Dataset created from previous Datasets in the lineage. + */ + def update(newData: T): Unit = { + persist(newData) + persistedQueue.enqueue(newData) + // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class: + // Users should call [[update()]] when a new Dataset has been created, + // before the Dataset has been materialized. + while (persistedQueue.size > 3) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + checkpoint(newData) + checkpointQueue.enqueue(newData) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. 
+ if (isCheckpointed(checkpointQueue.head)) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** Checkpoint the Dataset */ + protected def checkpoint(data: T): Unit + + /** Return true iff the Dataset is checkpointed */ + protected def isCheckpointed(data: T): Boolean + + /** + * Persist the Dataset. + * Note: This should handle checking the current [[StorageLevel]] of the Dataset. + */ + protected def persist(data: T): Unit + + /** Unpersist the Dataset */ + protected def unpersist(data: T): Unit + + /** Get list of checkpoint files for this given Dataset */ + protected def getCheckpointFiles(data: T): Iterable[String] + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.nonEmpty) { + removeCheckpointFile() + } + } + + /** + * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it. + val fs = FileSystem.get(sc.hadoopConfiguration) + getCheckpointFiles(old).foreach { checkpointFile => + try { + fs.delete(new Path(checkpointFile), true) + } catch { + case e: Exception => + logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 6e5dd119dd653..11a059536c50c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -17,11 +17,7 @@ package org.apache.spark.mllib.impl -import scala.collection.mutable - -import org.apache.hadoop.fs.{Path, FileSystem} - -import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel @@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as * unpersisting and removing checkpoint files. * - * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created, + * Users should call update() when a new graph has been created, * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are * responsible for materializing the graph to ensure that persisting and checkpointing actually * occur. * - * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following: + * When update() is called, this does the following: * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. * - Unpersist graphs from queue until there are at most 3 persisted graphs. * - If using checkpointing and the checkpoint interval has been reached, @@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel * Example usage: * {{{ * val (graph1, graph2, graph3, ...) = ... 
- * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2) + * val cp = new PeriodicGraphCheckpointer(2, sc) * graph1.vertices.count(); graph1.edges.count() * // persisted: graph1 * cp.updateGraph(graph2) @@ -73,99 +69,30 @@ import org.apache.spark.storage.StorageLevel * // checkpointed: graph4 * }}} * - * @param currentGraph Initial graph * @param checkpointInterval Graphs will be checkpointed at this interval * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib. + * TODO: Move this out of MLlib? */ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( - var currentGraph: Graph[VD, ED], - val checkpointInterval: Int) extends Logging { - - /** FIFO queue of past checkpointed RDDs */ - private val checkpointQueue = mutable.Queue[Graph[VD, ED]]() - - /** FIFO queue of past persisted RDDs */ - private val persistedQueue = mutable.Queue[Graph[VD, ED]]() - - /** Number of times [[updateGraph()]] has been called */ - private var updateCount = 0 - - /** - * Spark Context for the Graphs given to this checkpointer. - * NOTE: This code assumes that only one SparkContext is used for the given graphs. - */ - private val sc = currentGraph.vertices.sparkContext + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { - updateGraph(currentGraph) + override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint() - /** - * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed. - * Since this handles persistence and checkpointing, this should be called before the graph - * has been materialized. - * - * @param newGraph New graph created from previous graphs in the lineage. - */ - def updateGraph(newGraph: Graph[VD, ED]): Unit = { - if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) { - newGraph.persist() - } - persistedQueue.enqueue(newGraph) - // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class: - // Users should call [[updateGraph()]] when a new graph has been created, - // before the graph has been materialized. - while (persistedQueue.size > 3) { - val graphToUnpersist = persistedQueue.dequeue() - graphToUnpersist.unpersist(blocking = false) - } - updateCount += 1 + override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed - // Handle checkpointing (after persisting) - if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { - // Add new checkpoint before removing old checkpoints. - newGraph.checkpoint() - checkpointQueue.enqueue(newGraph) - // Remove checkpoints before the latest one. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // Delete the oldest checkpoint only if the next checkpoint exists. - if (checkpointQueue.get(1).get.isCheckpointed) { - removeCheckpointFile() - } else { - canDelete = false - } - } + override protected def persist(data: Graph[VD, ED]): Unit = { + if (data.vertices.getStorageLevel == StorageLevel.NONE) { + data.persist() } } - /** - * Call this at the end to delete any remaining checkpoint files. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.size > 0) { - removeCheckpointFile() - } - } + override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) - /** - * Dequeue the oldest checkpointed Graph, and remove its checkpoint files. 
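The calling pattern that the PeriodicCheckpointer scaladoc above prescribes (call update() before materializing each new Dataset; at most three Datasets stay persisted; a checkpoint is written every checkpointInterval updates and older checkpoint files are removed) looks roughly like the sketch below. This is a minimal sketch only: it uses the PeriodicRDDCheckpointer added later in this patch, ignores its private[mllib] visibility, and the checkpoint directory, RDD contents, and iteration logic are illustrative assumptions.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.rdd.RDD

// Sketch: drive a checkpointer across an iterative computation.
def iterate(sc: SparkContext, initial: RDD[Double], numIterations: Int): RDD[Double] = {
  sc.setCheckpointDir("/tmp/checkpoints")  // illustrative path; required for checkpointing
  val checkpointer = new PeriodicRDDCheckpointer[Double](3, sc)  // checkpoint every 3 updates
  var current = initial
  checkpointer.update(current)  // register BEFORE materializing
  current.count()               // materializing triggers the persist (and possibly a checkpoint)
  (1 to numIterations).foreach { _ =>
    val next = current.map(_ * 0.9)  // some new RDD derived from the previous lineage
    checkpointer.update(next)
    next.count()
    current = next
  }
  checkpointer.deleteAllCheckpoints()  // remove any remaining checkpoint files at the end
  current
}
```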
- * This prints a warning but does not fail if the files cannot be removed. - */ - private def removeCheckpointFile(): Unit = { - val old = checkpointQueue.dequeue() - // Since the old checkpoint is not deleted by Spark, we manually delete it. - val fs = FileSystem.get(sc.hadoopConfiguration) - old.getCheckpointFiles.foreach { checkpointFile => - try { - fs.delete(new Path(checkpointFile), true) - } catch { - case e: Exception => - logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " + - checkpointFile) - } - } + override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = { + data.getCheckpointFiles } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala new file mode 100644 index 0000000000000..f31ed2aa90a64 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * This class helps with persisting and checkpointing RDDs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call update() when a new RDD has been created, + * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are + * responsible for materializing the RDD to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs. + * - Unpersist RDDs from queue until there are at most 3 persisted RDDs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which RDDs should be + * checkpointed). + * - This class removes checkpoint files once later RDDs have been checkpointed. + * However, references to the older RDDs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (rdd1, rdd2, rdd3, ...) = ... 
+ * val cp = new PeriodicRDDCheckpointer(2, sc) + * rdd1.count(); + * // persisted: rdd1 + * cp.update(rdd2) + * rdd2.count(); + * // persisted: rdd1, rdd2 + * // checkpointed: rdd2 + * cp.update(rdd3) + * rdd3.count(); + * // persisted: rdd1, rdd2, rdd3 + * // checkpointed: rdd2 + * cp.update(rdd4) + * rdd4.count(); + * // persisted: rdd2, rdd3, rdd4 + * // checkpointed: rdd4 + * cp.update(rdd5) + * rdd5.count(); + * // persisted: rdd3, rdd4, rdd5 + * // checkpointed: rdd4 + * }}} + * + * @param checkpointInterval RDDs will be checkpointed at this interval + * @tparam T RDD element type + * + * TODO: Move this out of MLlib? + */ +private[mllib] class PeriodicRDDCheckpointer[T]( + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { + + override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint() + + override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed + + override protected def persist(data: RDD[T]): Unit = { + if (data.getStorageLevel == StorageLevel.NONE) { + data.persist() + } + } + + override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) + + override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { + data.getCheckpointFile.map(x => x) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d82ba2456df1a..88914fa875990 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setByte(0, 0) row.setInt(1, sm.numRows) row.setInt(2, sm.numCols) - row.update(3, sm.colPtrs.toSeq) - row.update(4, sm.rowIndices.toSeq) - row.update(5, sm.values.toSeq) + row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any]))) + row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any]))) + row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, sm.isTransposed) case dm: DenseMatrix => @@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setInt(2, dm.numCols) row.setNullAt(3) row.setNullAt(4) - row.update(5, dm.values.toSeq) + row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, dm.isTransposed) } row @@ -179,14 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) - val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray + val values = row.getArray(5).toArray.map(_.asInstanceOf[Double]) val isTransposed = row.getBoolean(6) tpe match { case 0 => - val colPtrs = - row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray - val rowIndices = - row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray + val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int]) + val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int]) new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) case 1 => new DenseMatrix(numRows, numCols, values, isTransposed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index 
9669c364bad8f..b416d50a5631e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -25,3 +25,11 @@ import org.apache.spark.annotation.Experimental */ @Experimental case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) + +/** + * :: Experimental :: + * Represents QR factors. + */ +@Experimental +case class QRDecomposition[UType, VType](Q: UType, R: VType) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 0cb28d78bec05..89a1818db0d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -187,15 +187,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) - row.update(2, indices.toSeq) - row.update(3, values.toSeq) + row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any]))) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row case DenseVector(values) => val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) - row.update(3, values.toSeq) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row } } @@ -209,14 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { tpe match { case 0 => val size = row.getInt(1) - val indices = - row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray - val values = - row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray + val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int]) + val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) new SparseVector(size, indices, values) case 1 => - val values = - row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray + val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) new DenseVector(values) } } @@ -637,6 +634,8 @@ class SparseVector( require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. 
You provided ${indices.length} indices and " + s" ${values.length} values.") + require(indices.length <= size, s"You provided ${indices.length} indices and values, " + + s"which exceeds the specified vector size ${size}.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 1626da9c3d2ee..bfc90c9ef8527 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -22,7 +22,7 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, - svd => brzSvd} + svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -497,6 +497,50 @@ class RowMatrix( columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma) } + /** + * Compute QR decomposition for [[RowMatrix]]. The implementation is designed to optimize the QR + * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape. + * Reference: + * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce + * architectures" ([[http://dx.doi.org/10.1145/1996092.1996103]]) + * + * @param computeQ whether to compute Q + * @return QRDecomposition(Q, R), Q = null if computeQ = false. + */ + def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { + val col = numCols().toInt + // split rows horizontally into smaller matrices, and compute QR for each of them + val blockQRs = rows.glom().map { partRows => + val bdm = BDM.zeros[Double](partRows.length, col) + var i = 0 + partRows.foreach { row => + bdm(i, ::) := row.toBreeze.t + i += 1 + } + breeze.linalg.qr.reduced(bdm).r + } + + // combine the R part from previous results vertically into a tall matrix + val combinedR = blockQRs.treeReduce{ (r1, r2) => + val stackedR = BDM.vertcat(r1, r2) + breeze.linalg.qr.reduced(stackedR).r + } + val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix) + val finalQ = if (computeQ) { + try { + val invR = inv(combinedR) + this.multiply(Matrices.fromBreeze(invR)) + } catch { + case err: MatrixSingularException => + logWarning("R is not invertible; returning Q as null") + null + } + } else { + null + } + QRDecomposition(finalQ, finalR) + } + /** * Find all similar columns using the DIMSUM sampling algorithm, described in two papers * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 93290e6508529..56c549ef99cb7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -26,6 +26,7 @@ import org.apache.spark.storage.StorageLevel /** * A more compact class to represent a rating than Tuple3[Int, Int, Double]. + * @since 0.8.0 */ case class Rating(user: Int, product: Int, rating: Double) @@ -254,6 +255,7 @@ class ALS private ( /** * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
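The tallSkinnyQR method added to RowMatrix above factors each partition into a local R with Breeze, stacks and re-factors the partial R's in a treeReduce, and, when computeQ is true, recovers Q as A * R^{-1}. A minimal usage sketch, assuming a SparkContext sc is in scope and using made-up values:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Sketch: QR of a small tall-and-skinny matrix (4 rows x 2 columns).
val rows = sc.parallelize(Seq(
  Vectors.dense(3.0, 1.0),
  Vectors.dense(4.0, 2.0),
  Vectors.dense(0.0, 5.0),
  Vectors.dense(1.0, 1.0)))
val mat = new RowMatrix(rows)

val qr = mat.tallSkinnyQR(computeQ = true)
// qr.R is a local upper-triangular Matrix; qr.Q is a RowMatrix, or null when
// computeQ = false or R turns out to be singular (see the logWarning above).
println(qr.R)
```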
+ * @since 0.8.0 */ object ALS { /** @@ -269,6 +271,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param seed random seed + * @since 0.9.1 */ def train( ratings: RDD[Rating], @@ -293,6 +296,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into + * @since 0.8.0 */ def train( ratings: RDD[Rating], @@ -315,6 +319,7 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) + * @since 0.8.0 */ def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) : MatrixFactorizationModel = { @@ -331,6 +336,7 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) + * @since 0.8.0 */ def train(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { @@ -351,6 +357,7 @@ object ALS { * @param blocks level of parallelism to split computation into * @param alpha confidence parameter * @param seed random seed + * @since 0.8.1 */ def trainImplicit( ratings: RDD[Rating], @@ -377,6 +384,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param alpha confidence parameter + * @since 0.8.1 */ def trainImplicit( ratings: RDD[Rating], @@ -401,6 +409,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param alpha confidence parameter + * @since 0.8.1 */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { @@ -418,6 +427,7 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) + * @since 0.8.1 */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 43d219a49cf4e..261ca9cef0c5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -49,6 +49,7 @@ import org.apache.spark.storage.StorageLevel * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. + * @since 0.8.0 */ class MatrixFactorizationModel( val rank: Int, @@ -73,7 +74,9 @@ class MatrixFactorizationModel( } } - /** Predict the rating of one user for one product. */ + /** Predict the rating of one user for one product. + * @since 0.8.0 + */ def predict(user: Int, product: Int): Double = { val userVector = userFeatures.lookup(user).head val productVector = productFeatures.lookup(product).head @@ -111,6 +114,7 @@ class MatrixFactorizationModel( * * @param usersProducts RDD of (user, product) pairs. * @return RDD of Ratings. 
+ * @since 0.9.0 */ def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = { // Previously the partitions of ratings are only based on the given products. @@ -142,6 +146,7 @@ class MatrixFactorizationModel( /** * Java-friendly version of [[MatrixFactorizationModel.predict]]. + * @since 1.2.0 */ def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() @@ -157,6 +162,7 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the user. The score is an opaque value that indicates how strongly * recommended the product is. + * @since 1.1.0 */ def recommendProducts(user: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) @@ -173,6 +179,7 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the product. The score is an opaque value that indicates how strongly * recommended the user is. + * @since 1.1.0 */ def recommendUsers(product: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) @@ -180,6 +187,20 @@ class MatrixFactorizationModel( protected override val formatVersion: String = "1.0" + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Loader.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + * @since 1.3.0 + */ override def save(sc: SparkContext, path: String): Unit = { MatrixFactorizationModel.SaveLoadV1_0.save(this, path) } @@ -191,6 +212,7 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of * rating objects which contains the same userId, recommended productID and a "score" in the * rating field. Semantics of score is same as recommendProducts API + * @since 1.4.0 */ def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map { @@ -208,6 +230,7 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array * of rating objects which contains the recommended userId, same productID and a "score" in the * rating field. Semantics of score is same as recommendUsers API + * @since 1.4.0 */ def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map { @@ -218,6 +241,9 @@ class MatrixFactorizationModel( } } +/** + * @since 1.3.0 + */ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { import org.apache.spark.mllib.util.Loader._ @@ -292,6 +318,16 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { } } + /** + * Load a model from the given path. + * + * The model should have been saved by [[Saveable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. 
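Taken together, the ALS entry points and MatrixFactorizationModel methods annotated in the hunks above compose as in the following sketch (a SparkContext sc is assumed; the ratings, hyperparameters, and save path are illustrative):

```scala
import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}

// Sketch: explicit-feedback ALS on a tiny rating set.
val ratings = sc.parallelize(Seq(
  Rating(1, 10, 4.0),
  Rating(1, 11, 1.0),
  Rating(2, 10, 5.0),
  Rating(2, 12, 3.0)))

val model = ALS.train(ratings, rank = 5, iterations = 10, lambda = 0.01)

val score = model.predict(1, 11)          // rating of one user for one product
val top2 = model.recommendProducts(1, 2)  // top-2 recommended products for user 1

// Save to a directory and load it back (path is illustrative).
model.save(sc, "/tmp/als-model")
val sameModel = MatrixFactorizationModel.load(sc, "/tmp/als-model")
```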
+ * @return Model instance + * @since 1.3.0 + */ override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 58a50f9c19f14..93a6753efd4d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD * .setBandwidth(3.0) * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) * }}} + * @since 1.4.0 */ @Experimental class KernelDensity extends Serializable { @@ -51,6 +52,7 @@ class KernelDensity extends Serializable { /** * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). + * @since 1.4.0 */ def setBandwidth(bandwidth: Double): this.type = { require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") @@ -60,6 +62,7 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation. + * @since 1.4.0 */ def setSample(sample: RDD[Double]): this.type = { this.sample = sample @@ -68,6 +71,7 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation (for Java users). + * @since 1.4.0 */ def setSample(sample: JavaRDD[java.lang.Double]): this.type = { this.sample = sample.rdd.asInstanceOf[RDD[Double]] @@ -76,6 +80,7 @@ class KernelDensity extends Serializable { /** * Estimates probability density function at the given array of points. + * @since 1.4.0 */ def estimate(points: Array[Double]): Array[Double] = { val sample = this.sample diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index d321cc554c1cc..62da9f2ef22a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -33,6 +33,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. + * @since 1.1.0 */ @DeveloperApi class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { @@ -52,6 +53,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param sample The sample in dense/sparse vector format to be added into this summarizer. * @return This MultivariateOnlineSummarizer object. + * @since 1.1.0 */ def add(sample: Vector): this.type = { if (n == 0) { @@ -107,6 +109,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param other The other MultivariateOnlineSummarizer to be merged. * @return This MultivariateOnlineSummarizer object. 
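The add()/merge() contract documented above is designed for distributed aggregation: add() folds individual samples into a per-partition summarizer and merge() combines summarizers across partitions, for example via treeAggregate. A minimal sketch, assuming a SparkContext sc:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

// Sketch: column-wise statistics via add() per sample and merge() per partition.
val data = sc.parallelize(Seq(
  Vectors.dense(1.0, 10.0),
  Vectors.dense(2.0, 20.0),
  Vectors.dense(3.0, 30.0)))

val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
  (summarizer, sample) => summarizer.add(sample),
  (s1, s2) => s1.merge(s2))

println(summary.mean)      // [2.0, 20.0]
println(summary.variance)
```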
+ * @since 1.1.0 */ def merge(other: MultivariateOnlineSummarizer): this.type = { if (this.totalCnt != 0 && other.totalCnt != 0) { @@ -149,6 +152,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this } + /** + * @since 1.1.0 + */ override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -161,6 +167,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realMean) } + /** + * @since 1.1.0 + */ override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -183,14 +192,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realVariance) } + /** + * @since 1.1.0 + */ override def count: Long = totalCnt + /** + * @since 1.1.0 + */ override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } + /** + * @since 1.1.0 + */ override def max: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -202,6 +220,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMax) } + /** + * @since 1.1.0 + */ override def min: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -213,6 +234,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMin) } + /** + * @since 1.2.0 + */ override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -227,6 +251,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realMagnitude) } + /** + * @since 1.2.0 + */ override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 6a364c93284af..3bb49f12289e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -21,46 +21,55 @@ import org.apache.spark.mllib.linalg.Vector /** * Trait for multivariate statistical summary of a data matrix. + * @since 1.0.0 */ trait MultivariateStatisticalSummary { /** * Sample mean vector. + * @since 1.0.0 */ def mean: Vector /** * Sample variance vector. Should return a zero vector if the sample size is 1. + * @since 1.0.0 */ def variance: Vector /** * Sample size. + * @since 1.0.0 */ def count: Long /** * Number of nonzero elements (including explicitly presented zero values) in each column. + * @since 1.0.0 */ def numNonzeros: Vector /** * Maximum value of each column. + * @since 1.0.0 */ def max: Vector /** * Minimum value of each column. 
+ * @since 1.0.0 */ def min: Vector /** * Euclidean magnitude of each column + * @since 1.2.0 */ def normL2: Vector /** * L1 norm of each column + * @since 1.2.0 */ def normL1: Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 90332028cfb3a..f84502919e381 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD /** * :: Experimental :: * API for statistical functions in MLlib. + * @since 1.1.0 */ @Experimental object Statistics { @@ -41,6 +42,7 @@ object Statistics { * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. + * @since 1.1.0 */ def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() @@ -52,6 +54,7 @@ object Statistics { * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. + * @since 1.1.0 */ def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) @@ -68,6 +71,7 @@ object Statistics { * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. + * @since 1.1.0 */ def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) @@ -81,10 +85,14 @@ object Statistics { * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s + * @since 1.1.0 */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) - /** Java-friendly version of [[corr()]] */ + /** + * Java-friendly version of [[corr()]] + * @since 1.4.1 + */ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) @@ -101,10 +109,14 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. + * @since 1.1.0 */ def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) - /** Java-friendly version of [[corr()]] */ + /** + * Java-friendly version of [[corr()]] + * @since 1.4.1 + */ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) @@ -121,6 +133,7 @@ object Statistics { * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * @since 1.1.0 */ def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) @@ -135,6 +148,7 @@ object Statistics { * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. 
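The Statistics entry points annotated above can be exercised as in this short sketch (a SparkContext sc is assumed; the data are made up):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

// Sketch: correlation between two RDD[Double]s, plus a chi-squared goodness-of-fit test.
val x = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
val y = sc.parallelize(Seq(2.0, 4.0, 6.0, 8.0))
val pearson = Statistics.corr(x, y)              // Pearson by default
val spearman = Statistics.corr(x, y, "spearman")

val observed = Vectors.dense(4.0, 6.0, 5.0)
val gof = Statistics.chiSqTest(observed)         // goodness of fit against uniform expected counts
println(s"p-value = ${gof.pValue}, df = ${gof.degreesOfFreedom}")
```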
+ * @since 1.1.0 */ def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) @@ -145,6 +159,7 @@ object Statistics { * @param observed The contingency matrix (containing either counts or relative frequencies). * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * @since 1.1.0 */ def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) @@ -157,6 +172,7 @@ object Statistics { * Real-valued features will be treated as categorical for each distinct value. * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. + * @since 1.1.0 */ def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index cf51b24ff777f..9aa7763d7890d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.util.MLUtils * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution + * @since 1.3.0 */ @DeveloperApi class MultivariateGaussian ( @@ -60,12 +61,16 @@ class MultivariateGaussian ( */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - /** Returns density of this multivariate Gaussian at given point, x */ + /** Returns density of this multivariate Gaussian at given point, x + * @since 1.3.0 + */ def pdf(x: Vector): Double = { pdf(x.toBreeze) } - /** Returns the log-density of this multivariate Gaussian at given point, x */ + /** Returns the log-density of this multivariate Gaussian at given point, x + * @since 1.3.0 + */ def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index a835f96d5d0e3..9ce6faa137c41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging { false } + // Prepare periodic checkpointers + val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + timer.stop("init") logDebug("##########") logDebug("Building tree 0") logDebug("##########") - var data = input // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy).run(data) + val 
firstTreeModel = new DecisionTree(treeStrategy).run(input) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) // Note: A model of type regression is used since we require raw prediction @@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging { var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + if (validate) validatePredErrorCheckpointer.update(validatePredError) var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 var bestM = 1 - // pseudo-residual for second iteration - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } - var m = 1 - while (m < numIterations) { + var doneLearning = false + while (m < numIterations && !doneLearning) { + // Update data with pseudo-residuals + val data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") val model = new DecisionTree(treeStrategy).run(data) timer.stop(s"building tree $m") - // Create partial model + // Update partial model baseLearners(m) = model // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. // Technically, the weight should be optimized for the particular loss. // However, the behavior should be reasonable, though not optimal. 
baseLearnerWeights(m) = learningRate - // Note: A model of type regression is used since we require raw prediction - val partialModel = new GradientBoostedTreesModel( - Regression, baseLearners.slice(0, m + 1), - baseLearnerWeights.slice(0, m + 1)) predError = GradientBoostedTreesModel.updatePredictionError( input, predError, baseLearnerWeights(m), baseLearners(m), loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) if (validate) { @@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging { validatePredError = GradientBoostedTreesModel.updatePredictionError( validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + validatePredErrorCheckpointer.update(validatePredError) val currentValidateError = validatePredError.values.mean() if (bestValidateError - currentValidateError < validationTol) { - return new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, - baseLearners.slice(0, bestM), - baseLearnerWeights.slice(0, bestM)) + doneLearning = true } else if (currentValidateError < bestValidateError) { - bestValidateError = currentValidateError - bestM = m + 1 + bestValidateError = currentValidateError + bestM = m + 1 } } - // Update data with pseudo-residuals - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } m += 1 } @@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + predErrorCheckpointer.deleteAllCheckpoints() + validatePredErrorCheckpointer.deleteAllCheckpoints() if (persistedInput) input.unpersist() if (validate) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 2d6b01524ff3d..9fd30c9b56319 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * learning rate should be between in the interval (0, 1] * @param validationTol Useful when runWithValidation is used. If the error rate on the * validation input between two iterations is less than the validationTol - * then stop. Ignored when [[run]] is used. + * then stop. Ignored when + * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ @Experimental case class BoostingStrategy( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 380291ac22bd3..9fe264656ede7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -128,9 +128,13 @@ private[spark] object DecisionTreeMetadata extends Logging { // based on the number of training examples. 
if (strategy.categoricalFeaturesInfo.nonEmpty) { val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + val maxCategory = + strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 require(maxCategoriesPerFeature <= maxPossibleBins, - s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + - s"in categorical features (= $maxCategoriesPerFeature)") + s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + + s"number of values in each categorical feature, but categorical feature $maxCategory " + + s"has $maxCategoriesPerFeature values. Consider removing this and other categorical " + + "features with a large number of values, or add more training examples.") } val unorderedFeatures = new mutable.HashSet[Int]() diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index b48f190f599a2..d272a42c8576f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; import scala.Tuple2; @@ -59,7 +60,10 @@ public void tearDown() { @Test public void localLDAModel() { - LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics()); + Matrix topics = LDASuite$.MODULE$.tinyTopics(); + double[] topicConcentration = new double[topics.numRows()]; + Arrays.fill(topicConcentration, 1.0D / topics.numRows()); + LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D); // Check: basic parameters assertEquals(model.k(), tinyK); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 82c345491bb3c..a7bc77965fefd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setLossType("logistic") + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 1b6b69c7dc71e..ab711c8e4b215 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)))) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2) ParamsSuite.checkParams(model) } @@ -167,9 +167,19 @@ private object RandomForestClassifierSuite { val newModel = rf.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( - oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures, + numClasses) TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) + assert(newModel.numClasses == numClasses) + val results = newModel.transform(newData) + results.select("rawPrediction", "prediction").collect().foreach { + case Row(raw: Vector, prediction: Double) => { + assert(raw.size == numClasses) + val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2 + assert(predFromRaw == prediction) + } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 65846a846b7b4..321eeb843941c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) } test("input column without ML attribute") { @@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) - assert(group.getAttr(1) === 
BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index c4b45aee06384..436e66bab09b0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -18,12 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ class RFormulaParserSuite extends SparkFunSuite { - private def checkParse(formula: String, label: String, terms: Seq[String]) { - val parsed = RFormulaParser.parse(formula) - assert(parsed.label == label) - assert(parsed.terms == terms) + private def checkParse( + formula: String, + label: String, + terms: Seq[String], + schema: StructType = null) { + val resolved = RFormulaParser.parse(formula).resolve(schema) + assert(resolved.label == label) + assert(resolved.terms == terms) } test("parse simple formulas") { @@ -32,4 +37,46 @@ class RFormulaParserSuite extends SparkFunSuite { checkParse("y ~ ._foo ", "y", Seq("._foo")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } + + test("parse dot") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ .", "a", Seq("b", "c"), schema) + } + + test("parse deletion") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ c - b", "a", Seq("c"), schema) + } + + test("parse additions and deletions in order") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ . - b + . 
- c", "a", Seq("b"), schema) + } + + test("dot ignores complex column types") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "tinyint", false) + .add("c", "map", true) + checkParse("a ~ .", "a", Seq("b"), schema) + } + + test("parse intercept") { + assert(RFormulaParser.parse("a ~ b").hasIntercept) + assert(RFormulaParser.parse("a ~ b + 1").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 0").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 1 + 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 0").hasIntercept) + assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 8148c553e9051..6aed3243afce8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) } + + test("attribute generation") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array( + new BinaryAttribute(Some("a__bar"), Some(1)), + new BinaryAttribute(Some("a__foo"), Some(2)), + new NumericAttribute(Some("b"), Some(3)))) + assert(attrs === expectedAttrs) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 9682edcd9ba84..dbdce0c9dea54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions.min() < -1) } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val df = sqlContext.createDataFrame(data) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff 
--git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala new file mode 100644 index 0000000000000..66e4b170bae80 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Row} + +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + private val schema = StructType( + Array( + StructField("label", DoubleType), + StructField("features", DoubleType), + StructField("weight", DoubleType))) + + private val predictionSchema = StructType(Array(StructField("features", DoubleType))) + + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { + val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d)) + val parallelData = sc.parallelize(data) + + sqlContext.createDataFrame(parallelData, schema) + } + + private def generatePredictionInput(features: Seq[Double]): DataFrame = { + val data = Seq.tabulate(features.size)(i => Row(features(i))) + + val parallelData = sc.parallelize(data) + sqlContext.createDataFrame(parallelData, predictionSchema) + } + + test("isotonic regression predictions") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + val trainer = new IsotonicRegression().setIsotonicParam(true) + + val model = trainer.fit(dataset) + + val predictions = model + .transform(dataset) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) + + assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.parentModel.isotonic) + } + + test("antitonic regression predictions") { + val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1)) + val trainer = new IsotonicRegression().setIsotonicParam(false) + + val model = trainer.fit(dataset) + val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1)) + } + + test("params validation") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression + ParamsSuite.checkParams(ir) + val model = 
ir.fit(dataset) + ParamsSuite.checkParams(model) + } + + test("default params") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression() + assert(ir.getLabelCol === "label") + assert(ir.getFeaturesCol === "features") + assert(ir.getWeightCol === "weight") + assert(ir.getPredictionCol === "prediction") + assert(ir.getIsotonicParam === true) + + val model = ir.fit(dataset) + model.transform(dataset) + .select("label", "features", "prediction", "weight") + .collect() + + assert(model.getLabelCol === "label") + assert(model.getFeaturesCol === "features") + assert(model.getWeightCol === "weight") + assert(model.getPredictionCol === "prediction") + assert(model.getIsotonicParam === true) + assert(model.hasParent) + } + + test("set parameters") { + val isotonicRegression = new IsotonicRegression() + .setIsotonicParam(false) + .setWeightParam("w") + .setFeaturesCol("f") + .setLabelCol("l") + .setPredictionCol("p") + + assert(isotonicRegression.getIsotonicParam === false) + assert(isotonicRegression.getWeightCol === "w") + assert(isotonicRegression.getFeaturesCol === "f") + assert(isotonicRegression.getLabelCol === "l") + assert(isotonicRegression.getPredictionCol === "p") + } + + test("missing column") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + + intercept[IllegalArgumentException] { + new IsotonicRegression().setWeightParam("w").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setFeaturesCol("f").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setLabelCol("l").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index fd653296c9d97..d7b291d5a6330 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { @@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = 
setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B))) inputDStream.count() @@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -147,7 +156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // train and predict - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase .setNumIterations(10) val numBatches = 10 val emptyInput = Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 376a87f0511b4..c43e1e575c09c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM} +import breeze.linalg.{DenseMatrix => BDM, max, argmax} import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.Edge @@ -31,7 +31,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { import LDASuite._ test("LocalLDAModel") { - val model = new LocalLDAModel(tinyTopics) + val model = new LocalLDAModel(tinyTopics, + Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D) // Check: basic parameters assert(model.k === tinyK) @@ -82,21 +83,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.topicsMatrix === localModel.topicsMatrix) // Check: topic summaries - // The odd decimal formatting and sorting is a hack to do a robust comparison. 
- val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => - assert(t1 === t2) + val topicSummary = model.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) => + assert(topics ~== topicsLocal absTol 0.01) } // Check: per-doc topic distributions @@ -196,10 +190,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // verify the result, Note this generate the identical result as // [[https://github.com/Blei-Lab/onlineldavb]] - val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1) - assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2) + val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t) + val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t) + val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950) + val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050) + assert(topic1 ~== expectedTopic1 absTol 0.01) + assert(topic2 ~== expectedTopic2 absTol 0.01) } test("OnlineLDAOptimizer with toy data") { @@ -235,6 +231,114 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("LocalLDAModel logPerplexity") { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + // obtained from LDA model trained in gensim, see below + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + + def toydata: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + val docs = sc.parallelize(toydata) + + + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, 
offset=1024) + print(lda.log_perplexity(corpus)) + > -3.69051285096 + */ + + assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D) + } + + test("LocalLDAModel predict") { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + // obtained from LDA model trained in gensim, see below + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + + def toydata: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + val docs = sc.parallelize(toydata) + + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(list(lda.get_document_topics(corpus))) + > [[(0, 0.99504950495049516)], [(0, 0.99504950495049516)], + > [(0, 0.99504950495049516)], [(1, 0.99504950495049516)], + > [(1, 0.99504950495049516)], [(1, 0.99504950495049516)]] + */ + + val expectedPredictions = List( + (0, 0.99504), (0, 0.99504), + (0, 0.99504), (1, 0.99504), + (1, 0.99504), (1, 0.99504)) + + val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) => + // convert results to expectedPredictions format, which only has highest probability topic + val topicsBz = topics.toBreeze.toDenseVector + (id, (argmax(topicsBz), max(topicsBz))) + }.sortByKey() + .values + .collect() + + assert(expectedPredictions.zip(actualPredictions).forall { case (expected, actual) => + expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D) + }) + } + test("OnlineLDAOptimizer with asymmetric prior") { def toydata: Array[(Long, Vector)] = Array( Vectors.sparse(6, Array(0, 1), Array(1, 1)), @@ -287,7 +391,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("model save/load") { // Test for LocalLDAModel.
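The predict test above collapses each per-document topic distribution to its most likely topic with Breeze's max/argmax, which is why those symbols were added to the breeze.linalg import; a small sketch with an illustrative distribution:

import breeze.linalg.{argmax, max, DenseVector => BDV}

val topicDist = BDV(0.005, 0.995)    // posterior over k = 2 topics for one document
val bestTopic = argmax(topicDist)    // 1: index of the largest entry
val bestWeight = max(topicDist)      // 0.995
(bestTopic, bestWeight)              // the (topic, weight) pair compared against the gensim output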
- val localModel = new LocalLDAModel(tinyTopics) + val localModel = new LocalLDAModel(tinyTopics, + Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D) val tempDir1 = Utils.createTempDir() val path1 = tempDir1.toURI.toString @@ -313,6 +418,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(samelocalModel.topicsMatrix === localModel.topicsMatrix) assert(samelocalModel.k === localModel.k) assert(samelocalModel.vocabSize === localModel.vocabSize) + assert(samelocalModel.docConcentration === localModel.docConcentration) + assert(samelocalModel.topicConcentration === localModel.topicConcentration) + assert(samelocalModel.gammaShape === localModel.gammaShape) val sameDistributedModel = DistributedLDAModel.load(sc, path2) assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix) @@ -321,6 +429,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes) assert(distributedModel.docConcentration === sameDistributedModel.docConcentration) assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration) + assert(distributedModel.gammaShape === sameDistributedModel.gammaShape) assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals) val graph = distributedModel.graph @@ -339,6 +448,46 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("EMLDAOptimizer with empty docs") { + val vocabSize = 6 + val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty)) + val emptyDocs = emptyDocsArray + .zipWithIndex.map { case (wordCounts, docId) => + (docId.toLong, wordCounts) + } + val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + + val op = new EMLDAOptimizer() + val lda = new LDA() + .setK(3) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op) + + val model = lda.run(distributedEmptyDocs) + assert(model.vocabSize === vocabSize) + } + + test("OnlineLDAOptimizer with empty docs") { + val vocabSize = 6 + val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty)) + val emptyDocs = emptyDocsArray + .zipWithIndex.map { case (wordCounts, docId) => + (docId.toLong, wordCounts) + } + val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + + val op = new OnlineLDAOptimizer() + val lda = new LDA() + .setK(3) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op) + + val model = lda.run(distributedEmptyDocs) + assert(model.vocabSize === vocabSize) + } + } private[clustering] object LDASuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index ac01622b8a089..3645d29dccdb2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom @@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 
30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + test("accuracy for single center and equivalence to grand average") { // set parameters val numBatches = 10 @@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) @@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 9f107c89f6d80..6dd2dc926acc5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(sequences, 2).cache() - def compareResult( - expectedValue: Array[(Array[Int], Long)], - actualValue: Array[(Array[Int], Long)]): Boolean = { - expectedValue.map(x => (x._1.toSeq, x._2)).toSet == - actualValue.map(x => (x._1.toSeq, x._2)).toSet - } - val prefixspan = new PrefixSpan() .setMinSupport(0.33) .setMaxPatternLength(50) @@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4, 5), 2L), (Array(5), 3L) ) - assert(compareResult(expectedValue1, result1.collect())) + assert(compareResults(expectedValue1, result1.collect())) prefixspan.setMinSupport(0.5).setMaxPatternLength(50) val result2 = prefixspan.run(rdd) @@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4), 4L), (Array(5), 3L) ) - assert(compareResult(expectedValue2, result2.collect())) + assert(compareResults(expectedValue2, result2.collect())) prefixspan.setMinSupport(0.33).setMaxPatternLength(2) val result3 = prefixspan.run(rdd) @@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4, 5), 2L), (Array(5), 3L) ) - assert(compareResult(expectedValue3, result3.collect())) + assert(compareResults(expectedValue3, result3.collect())) + } + + private def compareResults( + expectedValue: Array[(Array[Int], Long)], + actualValue: Array[(Array[Int], Long)]): Boolean = { + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index d34888af2d73b..e331c75989187 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo import PeriodicGraphCheckpointerSuite._ - // TODO: Do I need to call count() on the graphs' RDDs? - test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, 10) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) checkPersistence(graphsToCheck, 1) var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) checkPersistence(graphsToCheck, iteration) iteration += 1 @@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo var graphsToCheck = Seq.empty[GraphToCheck] sc.setCheckpointDir(path) val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) graph1.edges.count() graph1.vertices.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) @@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graph.vertices.count() graph.edges.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) @@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite { } else { // Graph should never be checkpointed assert(!graph.isCheckpointed, "Graph should never have been checkpointed") - assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files") + assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files") } } catch { case e: AssertionError => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala new file mode 100644 index 0000000000000..b2a459a68b5fa --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PeriodicRDDCheckpointerSuite._ + + test("Persisting") { + var rddsToCheck = Seq.empty[RDDToCheck] + + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext) + checkpointer.update(rdd1) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkPersistence(rddsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkPersistence(rddsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var rddsToCheck = Seq.empty[RDDToCheck] + sc.setCheckpointDir(path) + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) + checkpointer.update(rdd1) + rdd1.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkCheckpoint(rddsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rdd.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkCheckpoint(rddsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + rddsToCheck.foreach { rdd => + confirmCheckpointRemoved(rdd.rdd) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicRDDCheckpointerSuite { + + case class RDDToCheck(rdd: RDD[Double], gIndex: Int) + + def createRDD(sc: SparkContext): RDD[Double] = { + sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0)) + } + + def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = { + rdds.foreach { g => + checkPersistence(g.rdd, g.gIndex, iteration) + } + } + + /** + * Check storage level of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(rdd.getStorageLevel == StorageLevel.NONE) + } else { + assert(rdd.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n") + } + } + + def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = { + rdds.reverse.foreach { g => + checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(rdd: RDD[_]): Unit = { + // Note: We cannot check rdd.isCheckpointed since that value is never updated. + // Instead, we check for the presence of the checkpoint files. 
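Taken together, the tests above document the intended driver loop for the new PeriodicRDDCheckpointer (an internal mllib helper); a minimal sketch of that loop, assuming an active SparkContext sc with a checkpoint directory already set and a hypothetical makeNextRDD producing each iteration's RDD:

val checkpointer = new PeriodicRDDCheckpointer[Double](2, sc)  // checkpoint every 2nd update
var iteration = 1
while (iteration <= 10) {
  val rdd = makeNextRDD(iteration)   // hypothetical source of the algorithm's current RDD
  checkpointer.update(rdd)           // persists the RDD and schedules checkpoints on the interval
  rdd.count()                        // checkpointing is lazy, so an action must materialize it
  iteration += 1
}
checkpointer.deleteAllCheckpoints()  // remove checkpoint files once the algorithm is finished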
+ // This test should continue to work even after this rdd.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). + val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration) + rdd.getCheckpointFile.foreach { checkpointFile => + assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkCheckpoint( + rdd: RDD[_], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd) + // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint. + if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(rdd.isCheckpointed, "RDD should be checkpointed") + assert(rdd.getCheckpointFile.nonEmpty, "RDD should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(rdd) + } + } else { + // RDD should never be checkpointed + assert(!rdd.isCheckpointed, "RDD should never have been checkpointed") + assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" + + s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 03be4119bdaca..1c37ea5123e82 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -57,6 +57,21 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(vec.values === values) } + test("sparse vector construction with mismatched indices/values array") { + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0, 7.0, 9.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0)) + } + } + + test("sparse vector construction with too many indices vs size") { + intercept[IllegalArgumentException] { + Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0)) + } + } + test("dense to array") { val vec = Vectors.dense(arr).asInstanceOf[DenseVector] assert(vec.toArray.eq(arr)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index b6cb53d0c743e..283ffec1d49d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random +import breeze.numerics.abs import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} import org.apache.spark.SparkFunSuite @@ -238,6 +239,22 @@ class 
RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("QR Decomposition") { + for (mat <- Seq(denseMat, sparseMat)) { + val result = mat.tallSkinnyQR(true) + val expected = breeze.linalg.qr.reduced(mat.toBreeze()) + val calcQ = result.Q + val calcR = result.R + assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze()))) + assert(closeToZero(abs(expected.r) - abs(calcR.toBreeze.asInstanceOf[BDM[Double]]))) + assert(closeToZero(calcQ.multiply(calcR).toBreeze - mat.toBreeze())) + // Decomposition without computing Q + val rOnly = mat.tallSkinnyQR(computeQ = false) + assert(rOnly.Q == null) + assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]]))) + } + } } class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index a2a4c5f6b8b70..34c07ed170816 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { def errorMessage = v1.toString + " did not equal " + v2.toString @@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) inputDStream.count() @@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) // collect the output as (true, estimated) tuples @@ -156,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // train and predict - val ssc = setupStreams(testInput, (inputDStream: 
DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -177,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { val numBatches = 10 val nPoints = 100 val emptyInput = Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 2521b3342181a..6fc9e8df621df 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) - (algos zip losses) map { - case (algo, loss) => { - val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy) - .runWithValidation(trainRdd, validateRdd) - val numTrees = gbtValidate.numTrees - assert(numTrees !== numIterations) - - // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) - val (errorWithoutValidation, errorWithValidation) = { - if (algo == Classification) { - val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) - } else { - (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) - } - } - assert(errorWithValidation <= errorWithoutValidation) - - // Test that results from evaluateEachIteration comply with runWithValidation. - // Note that convergenceTol is set to 0.0 - val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) - assert(evaluationArray.length === numIterations) - assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) - var i = 1 - while (i < numTrees) { - assert(evaluationArray(i) <= evaluationArray(i - 1)) - i += 1 + algos.zip(losses).foreach { case (algo, loss) => + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + val numTrees = gbtValidate.numTrees + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. 
+ val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) } } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. + // Note that convergenceTol is set to 0.0 + val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, checkpointInterval = 2) + val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + } private object GradientBoostedTreesSuite { diff --git a/pylintrc b/pylintrc index 061775960393b..6a675770da69a 100644 --- a/pylintrc +++ b/pylintrc @@ -84,7 +84,7 @@ enable= # If you would like to improve the code quality of pyspark, remove any of these disabled errors # run ./dev/lint-python and see if the errors raised by pylint can be fixed. 
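The two Checkpointing tests added above boil down to the same user-facing setup: give the SparkContext a checkpoint directory and set a checkpoint interval on the estimator. A minimal sketch against the spark.ml API, with an illustrative path and assuming a training DataFrame df with the default label/features columns:

import org.apache.spark.ml.regression.GBTRegressor

sc.setCheckpointDir("/tmp/gbt-checkpoints")  // illustrative location; any reliable filesystem works
val gbt = new GBTRegressor()
  .setMaxDepth(2)
  .setMaxIter(5)
  .setStepSize(0.1)
  .setCheckpointInterval(2)                  // checkpoint intermediate RDDs every 2 iterations
val model = gbt.fit(df)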
-disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable +disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable [REPORTS] diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 9ef93071d2e77..3b647985801b7 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -350,7 +350,26 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ - self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj) + self.save(_load_class) + self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) + d.pop('__doc__', None) + # handle property and staticmethod + dd = {} + for k, v in d.items(): + if isinstance(v, property): + k = ('property', k) + v = (v.fget, v.fset, v.fdel, v.__doc__) + elif isinstance(v, staticmethod) and hasattr(v, '__func__'): + k = ('staticmethod', k) + v = v.__func__ + elif isinstance(v, classmethod) and hasattr(v, '__func__'): + k = ('classmethod', k) + v = v.__func__ + dd[k] = v + self.save(dd) + self.write(pickle.TUPLE2) + self.write(pickle.REDUCE) + else: raise pickle.PicklingError("Can't pickle %r" % obj) @@ -708,6 +727,23 @@ def _make_skel_func(code, closures, base_globals = None): None, None, closure) +def _load_class(cls, d): + """ + Loads additional properties into class `cls`. 
+ """ + for k, v in d.items(): + if isinstance(k, tuple): + typ, k = k + if typ == 'property': + v = property(*v) + elif typ == 'staticmethod': + v = staticmethod(v) + elif typ == 'classmethod': + v = classmethod(v) + setattr(cls, k, v) + return cls + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 90cd342a6cf7f..60be85e53e2aa 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -52,7 +52,11 @@ def launch_gateway(): script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") if os.environ.get("SPARK_TESTING"): - submit_args = "--conf spark.ui.enabled=false " + submit_args + submit_args = ' '.join([ + "--conf spark.ui.enabled=false", + "--conf spark.buffer.pageSize=4mb", + submit_args + ]) command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 89117e492846b..5a82bc286d1e8 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) - >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42) + >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) - >>> allclose(model.treeWeights, [1.0, 1.0]) + >>> allclose(model.treeWeights, [1.0, 1.0, 1.0]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 86e654dd0779f..015e7a9d4900a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -525,7 +525,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text - (default) or repeatedly matching the regex (if gaps is true). + (default) or repeatedly matching the regex (if gaps is false). Optional parameters also allow filtering tokens using a minimal length. It returns an array of strings that can be empty. diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 58ad99d46e23b..900ade248c386 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -152,11 +152,19 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" return KMeansModel([c.toArray() for c in centers]) -class GaussianMixtureModel(object): +@inherit_doc +class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): + + """ + .. note:: Experimental - """A clustering model derived from the Gaussian Mixture Model method. + A clustering model derived from the Gaussian Mixture Model method. >>> from pyspark.mllib.linalg import Vectors, DenseMatrix + >>> from numpy.testing import assert_equal + >>> from shutil import rmtree + >>> import os, tempfile + >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 
0.9,0.8,0.75,0.935, ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) @@ -169,6 +177,25 @@ class GaussianMixtureModel(object): True >>> labels[4]==labels[5] True + + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = GaussianMixtureModel.load(sc, path) + >>> assert_equal(model.weights, sameModel.weights) + >>> mus, sigmas = list( + ... zip(*[(g.mu, g.sigma) for g in model.gaussians])) + >>> sameMus, sameSigmas = list( + ... zip(*[(g.mu, g.sigma) for g in sameModel.gaussians])) + >>> mus == sameMus + True + >>> sigmas == sameSigmas + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + >>> data = array([-5.1971, -2.5359, -3.8220, ... -5.2211, -5.0602, 4.7118, ... 6.8989, 3.4592, 4.6322, @@ -182,25 +209,15 @@ class GaussianMixtureModel(object): True >>> labels[3]==labels[4] True - >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1)) - >>> im = GaussianMixtureModel([0.5, 0.5], - ... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])), - ... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))]) - >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im) """ - def __init__(self, weights, gaussians): - self._weights = weights - self._gaussians = gaussians - self._k = len(self._weights) - @property def weights(self): """ Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i, and weights.sum == 1. """ - return self._weights + return array(self.call("weights")) @property def gaussians(self): @@ -208,12 +225,14 @@ def gaussians(self): Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i. """ - return self._gaussians + return [ + MultivariateGaussian(gaussian[0], gaussian[1]) + for gaussian in zip(*self.call("gaussians"))] @property def k(self): """Number of gaussians in mixture.""" - return self._k + return len(self.weights) def predict(self, x): """ @@ -238,17 +257,30 @@ def predictSoft(self, x): :return: membership_matrix. RDD of array of double values. """ if isinstance(x, RDD): - means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians]) + means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), - _convert_to_vector(self._weights), means, sigmas) + _convert_to_vector(self.weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) else: raise TypeError("x should be represented by an RDD, " "but got %s." % type(x)) + @classmethod + def load(cls, sc, path): + """Load the GaussianMixtureModel from disk. + + :param sc: SparkContext + :param path: str, path to where the model is stored. + """ + model = cls._load_java(sc, path) + wrapper = sc._jvm.GaussianMixtureModelWrapper(model) + return cls(wrapper) + class GaussianMixture(object): """ + .. note:: Experimental + Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. 
:param data: RDD of data points @@ -271,11 +303,10 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] - weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), - k, convergenceTol, maxIterations, seed, - initialModelWeights, initialModelMu, initialModelSigma) - mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] - return GaussianMixtureModel(weight, mvg_obj) + java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), + k, convergenceTol, maxIterations, seed, + initialModelWeights, initialModelMu, initialModelSigma) + return GaussianMixtureModel(java_model) class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg/__init__.py similarity index 100% rename from python/pyspark/mllib/linalg.py rename to python/pyspark/mllib/linalg/__init__.py diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 875d3b2d642c6..916de2d6fcdbd 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -21,7 +21,9 @@ if sys.version > '3': xrange = range + basestring = str +from pyspark import SparkContext from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -223,6 +225,10 @@ class JavaSaveable(Saveable): """ def save(self, sc, path): + if not isinstance(sc, SparkContext): + raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) self._java_model.save(sc._jsc.sc(), path) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 8fb71bac64a5e..b8118bdb7ca76 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -606,7 +606,7 @@ def _open_file(self): if not os.path.exists(d): os.makedirs(d) p = os.path.join(d, str(id(self))) - self._file = open(p, "wb+", 65536) + self._file = open(p, "w+b", 65536) self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) os.unlink(p) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index abb6522dde7b0..917de24f3536b 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -277,6 +277,66 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) + def _createFromRDD(self, rdd, schema, samplingRatio): + """ + Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. 
+ """ + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(struct) + rdd = rdd.map(converter) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + # take the first few rows to verify schema + rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + rdd = rdd.map(schema.toInternal) + return rdd, schema + + def _createFromLocal(self, data, schema): + """ + Create an RDD for DataFrame from an list or pandas.DataFrame, returns + the RDD and schema. + """ + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = [str(x) for x in data.columns] + data = [r.tolist() for r in data.to_records(index=False)] + + # make sure data could consumed multiple times + if not isinstance(data, list): + data = list(data) + + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + for row in data: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + data = [schema.toInternal(row) for row in data] + return self._sc.parallelize(data), schema + @since(1.3) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): @@ -340,49 +400,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") - if has_pandas and isinstance(data, pandas.DataFrame): - if schema is None: - schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] - - if not isinstance(data, RDD): - if not isinstance(data, list): - data = list(data) - try: - # data could be list, tuple, generator ... 
- rdd = self._sc.parallelize(data) - except Exception: - raise TypeError("cannot create an RDD from type: %s" % type(data)) + if isinstance(data, RDD): + rdd, schema = self._createFromRDD(data, schema, samplingRatio) else: - rdd = data - - if schema is None or isinstance(schema, (list, tuple)): - if isinstance(data, RDD): - struct = self._inferSchema(rdd, samplingRatio) - else: - struct = self._inferSchemaFromList(data) - if isinstance(schema, (list, tuple)): - for i, name in enumerate(schema): - struct.fields[i].name = name - schema = struct - converter = _create_converter(schema) - rdd = rdd.map(converter) - - elif isinstance(schema, StructType): - # take the first few rows to verify schema - rows = rdd.take(10) - for row in rows: - _verify_type(row, schema) - - else: - raise TypeError("schema should be StructType or list or None") - - # convert python objects to sql data - rdd = rdd.map(schema.toInternal) - + rdd, schema = self._createFromLocal(data, schema) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) + jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + df = DataFrame(jdf, self) + df._schema = schema + return df @since(1.3) def registerDataFrameAsTable(self, df, tableName): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d76e051bd73a1..0f3480c239187 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -441,6 +441,42 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + @since(1.5) + def sampleBy(self, col, fractions, seed=None): + """ + Returns a stratified sample without replacement based on the + fraction given on each stratum. + + :param col: column that defines strata + :param fractions: + sampling fraction for each stratum. If a stratum is not + specified, we treat its fraction as zero. + :param seed: random seed + :return: a new DataFrame that represents the stratified sample + + >>> from pyspark.sql.functions import col + >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) + >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) + >>> sampled.groupBy("key").count().orderBy("key").show() + +---+-----+ + |key|count| + +---+-----+ + | 0| 3| + | 1| 8| + +---+-----+ + + """ + if not isinstance(col, str): + raise ValueError("col must be a string, but got %r" % type(col)) + if not isinstance(fractions, dict): + raise ValueError("fractions must be a dict but got %r" % type(fractions)) + for k, v in fractions.items(): + if not isinstance(k, (float, int, long, basestring)): + raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) + fractions[k] = float(v) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. 
@@ -1314,6 +1350,11 @@ def freqItems(self, cols, support=None): freqItems.__doc__ = DataFrame.freqItems.__doc__ + def sampleBy(self, col, fractions, seed=None): + return self.df.sampleBy(col, fractions, seed) + + sampleBy.__doc__ = DataFrame.sampleBy.__doc__ + def _test(): import doctest diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d930f7db25d25..8024a8de07c98 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -59,7 +59,7 @@ __all__ += ['lag', 'lead', 'ntile'] __all__ += [ - 'date_format', + 'date_format', 'date_add', 'date_sub', 'add_months', 'months_between', 'year', 'quarter', 'month', 'hour', 'minute', 'second', 'dayofmonth', 'dayofyear', 'weekofyear'] @@ -716,7 +716,7 @@ def date_format(dateCol, format): [Row(date=u'04/08/2015')] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.date_format(dateCol, format)) + return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format)) @since(1.5) @@ -729,7 +729,7 @@ def year(col): [Row(year=2015)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.year(col)) + return Column(sc._jvm.functions.year(_to_java_column(col))) @since(1.5) @@ -742,7 +742,7 @@ def quarter(col): [Row(quarter=2)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.quarter(col)) + return Column(sc._jvm.functions.quarter(_to_java_column(col))) @since(1.5) @@ -755,7 +755,7 @@ def month(col): [Row(month=4)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.month(col)) + return Column(sc._jvm.functions.month(_to_java_column(col))) @since(1.5) @@ -768,7 +768,7 @@ def dayofmonth(col): [Row(day=8)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.dayofmonth(col)) + return Column(sc._jvm.functions.dayofmonth(_to_java_column(col))) @since(1.5) @@ -781,7 +781,7 @@ def dayofyear(col): [Row(day=98)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.dayofyear(col)) + return Column(sc._jvm.functions.dayofyear(_to_java_column(col))) @since(1.5) @@ -794,7 +794,7 @@ def hour(col): [Row(hour=13)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.hour(col)) + return Column(sc._jvm.functions.hour(_to_java_column(col))) @since(1.5) @@ -807,7 +807,7 @@ def minute(col): [Row(minute=8)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.minute(col)) + return Column(sc._jvm.functions.minute(_to_java_column(col))) @since(1.5) @@ -820,7 +820,7 @@ def second(col): [Row(second=15)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.second(col)) + return Column(sc._jvm.functions.second(_to_java_column(col))) @since(1.5) @@ -829,11 +829,93 @@ def weekofyear(col): Extract the week number of a given date as integer. 
>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(weekofyear('a').alias('week')).collect() + >>> df.select(weekofyear(df.a).alias('week')).collect() [Row(week=15)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.weekofyear(col)) + return Column(sc._jvm.functions.weekofyear(_to_java_column(col))) + + +@since(1.5) +def date_add(start, days): + """ + Returns the date that is `days` days after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_add(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 9))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_add(_to_java_column(start), days)) + + +@since(1.5) +def date_sub(start, days): + """ + Returns the date that is `days` days before `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_sub(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 7))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) + + +@since(1.5) +def add_months(start, months): + """ + Returns the date that is `months` months after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(add_months(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 5, 8))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.add_months(_to_java_column(start), months)) + + +@since(1.5) +def months_between(date1, date2): + """ + Returns the number of months between date1 and date2. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) + >>> df.select(months_between(df.t, df.d).alias('months')).collect() + [Row(months=3.9495967...)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) + + +@since(1.5) +def to_date(col): + """ + Converts the column of StringType or TimestampType into DateType. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_date(df.t).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.to_date(_to_java_column(col))) + + +@since(1.5) +def trunc(date, format): + """ + Returns date truncated to the unit specified by the format. 
+ + :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' + + >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d']) + >>> df.select(trunc(df.d, 'year').alias('year')).collect() + [Row(year=datetime.date(1997, 1, 1))] + >>> df.select(trunc(df.d, 'mon').alias('month')).collect() + [Row(month=datetime.date(1997, 2, 1))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) @since(1.5) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5aa6135dc1ee7..ebd3ea8db6a43 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -75,7 +75,7 @@ def sqlType(self): @classmethod def module(cls): - return 'pyspark.tests' + return 'pyspark.sql.tests' @classmethod def scalaUDT(cls): @@ -106,10 +106,45 @@ def __str__(self): return "(%s,%s)" % (self.x, self.y) def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ + return isinstance(other, self.__class__) and \ other.x == self.x and other.y == self.y +class PythonOnlyUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return '__main__' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return PythonOnlyPoint(datum[0], datum[1]) + + @staticmethod + def foo(): + pass + + @property + def props(self): + return {} + + +class PythonOnlyPoint(ExamplePoint): + """ + An example class to demonstrate UDT in only Python + """ + __UDT__ = PythonOnlyUDT() + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -395,10 +430,39 @@ def test_convert_row_to_dict(self): self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) + def test_udt(self): + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + + def check_datatype(datatype): + pickled = pickle.loads(pickle.dumps(datatype)) + assert datatype == pickled + scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert datatype == python_datatype + + check_datatype(ExamplePointUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + check_datatype(structtype_with_udt) + p = ExamplePoint(1.0, 2.0) + self.assertEqual(_infer_type(p), ExamplePointUDT()) + _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + + check_datatype(PythonOnlyUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + check_datatype(structtype_with_udt) + p = PythonOnlyPoint(1.0, 2.0) + self.assertEqual(_infer_type(p), PythonOnlyUDT()) + _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), 
ExamplePointUDT) @@ -406,36 +470,66 @@ def test_infer_schema_with_udt(self): point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), PythonOnlyUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_apply_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df = rdd.toDF(schema) + df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = (1.0, PythonOnlyPoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + df = self.sqlCtx.createDataFrame([row], schema) + point = df.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) + self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + def test_parquet_with_udt(self): - from pyspark.sql.tests import ExamplePoint + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.sc.parallelize([row]).toDF() + df0 = self.sqlCtx.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) + df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df0 = self.sqlCtx.createDataFrame([row]) + df0.write.parquet(output_dir, mode='overwrite') + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b97d50c945f24..6f74b7162f7cc 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -22,6 +22,7 @@ import calendar import json import re +import base64 from array import array if sys.version >= "3": @@ -31,6 +32,8 @@ from py4j.protocol import 
register_input_converter from py4j.java_gateway import JavaClass +from pyspark.serializers import CloudPickleSerializer + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", @@ -458,7 +461,7 @@ def __init__(self, fields=None): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeFields = None + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -501,6 +504,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) return self def simpleString(self): @@ -526,12 +530,9 @@ def toInternal(self, obj): if obj is None: return - if self._needSerializeFields is None: - self._needSerializeFields = any(f.needConversion() for f in self.fields) - - if self._needSerializeFields: + if self._needSerializeAnyField: if isinstance(obj, dict): - return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields)) + return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) else: @@ -550,7 +551,10 @@ def fromInternal(self, obj): if isinstance(obj, Row): # it's already converted by pickler return obj - values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)] + if self._needSerializeAnyField: + values = [f.fromInternal(v) for f, v in zip(self.fields, obj)] + else: + values = obj return _create_row(self.names, values) @@ -581,9 +585,10 @@ def module(cls): @classmethod def scalaUDT(cls): """ - The class name of the paired Scala UDT. + The class name of the paired Scala UDT (could be '', if there + is no corresponding one). 
""" - raise NotImplementedError("UDT must have a paired Scala UDT.") + return '' def needConversion(self): return True @@ -622,22 +627,37 @@ def json(self): return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } + if self.scalaUDT(): + assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT' + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + else: + ser = CloudPickleSerializer() + b = ser.dumps(type(self)) + schema = { + "type": "udt", + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "serializedClass": base64.b64encode(b).decode('utf8'), + "sqlType": self.sqlType().jsonValue() + } return schema @classmethod def fromJson(cls, json): - pyUDT = json["pyClass"] + pyUDT = str(json["pyClass"]) # convert unicode to str split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] m = __import__(pyModule, globals(), locals(), [pyClass]) - UDT = getattr(m, pyClass) + if not hasattr(m, pyClass): + s = base64.b64decode(json['serializedClass'].encode('utf-8')) + UDT = CloudPickleSerializer().loads(s) + else: + UDT = getattr(m, pyClass) return UDT() def __eq__(self, other): @@ -696,11 +716,6 @@ def _parse_datatype_json_string(json_string): >>> complex_maptype = MapType(complex_structtype, ... complex_arraytype, False) >>> check_datatype(complex_maptype) - - >>> check_datatype(ExamplePointUDT()) - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) """ return _parse_datatype_json_value(json.loads(json_string)) @@ -752,10 +767,6 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT """ if obj is None: return NullType() @@ -1090,11 +1101,6 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... 
""" # all objects are nullable if obj is None: @@ -1259,18 +1265,12 @@ def convert(self, obj, gateway_client): def _test(): import doctest from pyspark.context import SparkContext - # let doctest run in pyspark.sql.types, so DataTypes can be picklable - import pyspark.sql.types - from pyspark.sql import Row, SQLContext - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.types.__dict__.copy() + from pyspark.sql import SQLContext + globs = globals() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - (failure_count, test_count) = doctest.testmod( - pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java new file mode 100644 index 0000000000000..e3d3ba7a9ccc0 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.ArrayData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +public interface SpecializedGetters { + + boolean isNullAt(int ordinal); + + boolean getBoolean(int ordinal); + + byte getByte(int ordinal); + + short getShort(int ordinal); + + int getInt(int ordinal); + + long getLong(int ordinal); + + float getFloat(int ordinal); + + double getDouble(int ordinal); + + Decimal getDecimal(int ordinal, int precision, int scale); + + UTF8String getUTF8String(int ordinal); + + byte[] getBinary(int ordinal); + + CalendarInterval getInterval(int ordinal); + + InternalRow getStruct(int ordinal, int numFields); + + ArrayData getArray(int ordinal); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 684de6e81d67c..f3b462778dc10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -20,6 +20,8 @@ import java.util.Iterator; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; @@ -61,26 +63,18 @@ public final class UnsafeFixedWidthAggregationMap { private final boolean enablePerfMetrics; - /** - * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, - * false otherwise. - */ - public static boolean supportsGroupKeySchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - /** * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given * schema, false otherwise. */ public static boolean supportsAggregationBufferSchema(StructType schema) { for (StructField field: schema.fields()) { - if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + if (field.dataType() instanceof DecimalType) { + DecimalType dt = (DecimalType) field.dataType(); + if (dt.precision() > Decimal.MAX_LONG_DIGITS()) { + return false; + } + } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { return false; } } @@ -95,6 +89,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param groupingKeySchema the schema of the grouping key, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). + * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. 
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( @@ -103,11 +98,13 @@ public UnsafeFixedWidthAggregationMap( StructType groupingKeySchema, TaskMemoryManager memoryManager, int initialCapacity, + long pageSizeBytes, boolean enablePerfMetrics) { this.aggregationBufferSchema = aggregationBufferSchema; this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.map = + new BytesToBytesMap(memoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); this.enablePerfMetrics = enablePerfMetrics; // Initialize the buffer for aggregation value diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index fb084dd13b620..e7088edced1a1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.io.OutputStream; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -29,7 +31,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.types.Interval; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.types.DataTypes.*; @@ -65,12 +67,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { */ public static final Set settableFieldTypes; - /** - * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). 
- */ - public static final Set readableFieldTypes; - - // TODO: support DecimalType + // DecimalType(precision <= 18) is settable static { settableFieldTypes = Collections.unmodifiableSet( new HashSet<>( @@ -86,16 +83,6 @@ public static int calculateBitSetWidthInBytes(int numFields) { DateType, TimestampType }))); - - // We support get() on a superset of the types for which we support set(): - final Set _readableFieldTypes = new HashSet<>( - Arrays.asList(new DataType[]{ - StringType, - BinaryType, - IntervalType - })); - _readableFieldTypes.addAll(settableFieldTypes); - readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); } ////////////////////////////////////////////////////////////////////////////// @@ -232,6 +219,21 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + assertIndexIsValid(ordinal); + if (value == null) { + setNullAt(ordinal); + } else { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + setLong(ordinal, value.toUnscaledLong()); + } else { + // TODO(davies): support update decimal (hold a bounded space even it's null) + throw new UnsupportedOperationException(); + } + } + } + @Override public Object get(int ordinal) { throw new UnsupportedOperationException(); @@ -239,7 +241,7 @@ public Object get(int ordinal) { @Override public Object get(int ordinal, DataType dataType) { - if (dataType instanceof NullType) { + if (isNullAt(ordinal) || dataType instanceof NullType) { return null; } else if (dataType instanceof BooleanType) { return getBoolean(ordinal); @@ -256,7 +258,8 @@ public Object get(int ordinal, DataType dataType) { } else if (dataType instanceof DoubleType) { return getDouble(ordinal); } else if (dataType instanceof DecimalType) { - return getDecimal(ordinal); + DecimalType dt = (DecimalType) dataType; + return getDecimal(ordinal, dt.precision(), dt.scale()); } else if (dataType instanceof DateType) { return getInt(ordinal); } else if (dataType instanceof TimestampType) { @@ -265,6 +268,8 @@ public Object get(int ordinal, DataType dataType) { return getBinary(ordinal); } else if (dataType instanceof StringType) { return getUTF8String(ordinal); + } else if (dataType instanceof CalendarIntervalType) { + return getInterval(ordinal); } else if (dataType instanceof StructType) { return getStruct(ordinal, ((StructType) dataType).size()); } else { @@ -311,20 +316,28 @@ public long getLong(int ordinal) { @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - if (isNullAt(ordinal)) { - return Float.NaN; - } else { - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); - } + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); } @Override public double getDouble(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { assertIndexIsValid(ordinal); if (isNullAt(ordinal)) { - return Float.NaN; + return null; + } + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); } else { - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return 
Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); } } @@ -356,7 +369,7 @@ public byte[] getBinary(int ordinal) { } @Override - public Interval getInterval(int ordinal) { + public CalendarInterval getInterval(int ordinal) { if (isNullAt(ordinal)) { return null; } else { @@ -365,7 +378,7 @@ public Interval getInterval(int ordinal) { final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); final long microseconds = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); - return new Interval(months, microseconds); + return new CalendarInterval(months, microseconds); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 0ba31d3b9b743..f43a285cd6cad 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; -import org.apache.spark.unsafe.types.Interval; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** @@ -29,6 +31,47 @@ */ public class UnsafeRowWriters { + /** Writer for Decimal with precision under 18. */ + public static class CompactDecimalWriter { + + public static int getSize(Decimal input) { + return 0; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + target.setLong(ordinal, input.toUnscaledLong()); + return 0; + } + } + + /** Writer for Decimal with precision larger than 18. */ + public static class DecimalWriter { + + public static int getSize(Decimal input) { + // bounded size + return 16; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + final long offset = target.getBaseOffset() + cursor; + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + final int numBytes = bytes.length; + assert(numBytes <= 16); + + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L); + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, + target.getBaseObject(), offset, numBytes); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + return 16; + } + } + /** Writer for UTF8String. */ public static class UTF8StringWriter { @@ -46,7 +89,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String in target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } - // Write the string to the variable length portion. + // Write the bytes to the variable length portion. input.writeToMemory(target.getBaseObject(), offset); // Set the fixed length portion. @@ -72,7 +115,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } - // Write the string to the variable length portion. 
+ // Write the bytes to the variable length portion. ByteArray.writeToMemory(input, target.getBaseObject(), offset); // Set the fixed length portion. @@ -81,10 +124,56 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) } } + /** + * Writer for struct type where the struct field is backed by an {@link UnsafeRow}. + * + * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}. + * Non-UnsafeRow struct fields are handled directly in + * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection} + * by generating the Java code needed to convert them into UnsafeRow. + */ + public static class StructWriter { + public static int getSize(InternalRow input) { + int numBytes = 0; + if (input instanceof UnsafeRow) { + numBytes = ((UnsafeRow) input).getSizeInBytes(); + } else { + // This is handled directly in GenerateUnsafeProjection. + throw new UnsupportedOperationException(); + } + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + + public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) { + int numBytes = 0; + final long offset = target.getBaseOffset() + cursor; + if (input instanceof UnsafeRow) { + final UnsafeRow row = (UnsafeRow) input; + numBytes = row.getSizeInBytes(); + + // zero-out the padding bytes + if ((numBytes & 0x07) > 0) { + PlatformDependent.UNSAFE.putLong( + target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + } + + // Write the bytes to the variable length portion. + row.writeToMemory(target.getBaseObject(), offset); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + } else { + // This is handled directly in GenerateUnsafeProjection. + throw new UnsupportedOperationException(); + } + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + } + /** Writer for interval type. */ public static class IntervalWriter { - public static int write(UnsafeRow target, int ordinal, int cursor, Interval input) { + public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInterval input) { final long offset = target.getBaseOffset() + cursor; // Write the months and microseconds fields of Interval to the variable length portion. 
@@ -96,5 +185,4 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Interval inpu return 16; } } - } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 4c3f2c6557140..68c49feae938e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -48,7 +48,6 @@ final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final UnsafeProjection unsafeProjection; private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; @@ -62,7 +61,6 @@ public UnsafeExternalRowSorter( PrefixComparator prefixComparator, PrefixComputer prefixComputer) throws IOException { this.schema = schema; - this.unsafeProjection = UnsafeProjection.create(schema); this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); @@ -88,13 +86,12 @@ void setTestSpillFrequency(int frequency) { } @VisibleForTesting - void insertRow(InternalRow row) throws IOException { - UnsafeRow unsafeRow = unsafeProjection.apply(row); + void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( - unsafeRow.getBaseObject(), - unsafeRow.getBaseOffset(), - unsafeRow.getSizeInBytes(), + row.getBaseObject(), + row.getBaseOffset(), + row.getSizeInBytes(), prefix ); numRowsInserted++; @@ -113,7 +110,7 @@ private void cleanupResources() { } @VisibleForTesting - Iterator sort() throws IOException { + Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -121,7 +118,7 @@ Iterator sort() throws IOException { // here in order to prevent memory leaks. cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(); @@ -132,7 +129,7 @@ public boolean hasNext() { } @Override - public InternalRow next() { + public UnsafeRow next() { try { sortedIterator.loadNext(); row.pointTo( @@ -164,11 +161,11 @@ public InternalRow next() { } - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); } /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 5703de42393de..17659d7d960b0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -50,9 +50,9 @@ public class DataTypes { public static final DataType TimestampType = TimestampType$.MODULE$; /** - * Gets the IntervalType object. + * Gets the CalendarIntervalType object. */ - public static final DataType IntervalType = IntervalType$.MODULE$; + public static final DataType CalendarIntervalType = CalendarIntervalType$.MODULE$; /** * Gets the DoubleType object. 
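(Not part of the patch.) For reference, a minimal usage sketch of the date/time helpers added to python/pyspark/sql/functions.py above. It simply mirrors the doctests in that hunk; `sqlContext` is assumed to already exist, exactly as in the doctest globals, and the values shown in the comments follow the doctest expectations rather than a fresh run.

# Sketch: exercising the new pyspark date/time helpers from this patch.
from pyspark.sql.functions import (
    date_add, date_sub, add_months, months_between, to_date, trunc)

df = sqlContext.createDataFrame([('2015-04-08', '1997-02-28 10:30:00')], ['d', 't'])
df.select(date_add(df.d, 1), date_sub(df.d, 1)).show()   # 2015-04-09 and 2015-04-07
df.select(add_months(df.d, 1)).show()                     # 2015-05-08
df.select(months_between(df.t, df.d)).show()              # fractional number of months
df.select(to_date(df.t), trunc(df.d, 'mon')).show()       # 1997-02-28 and 2015-04-01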
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d1d89a1f48329..7ca20fe97fbef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -55,7 +55,6 @@ object CatalystTypeConverters { private def isWholePrimitive(dt: DataType): Boolean = dt match { case dt if isPrimitive(dt) => true - case ArrayType(elementType, _) => isWholePrimitive(elementType) case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) case _ => false } @@ -69,7 +68,7 @@ object CatalystTypeConverters { case StringType => StringConverter case DateType => DateConverter case TimestampType => TimestampConverter - case dt: DecimalType => BigDecimalConverter + case dt: DecimalType => new DecimalConverter(dt) case BooleanType => BooleanConverter case ByteType => ByteConverter case ShortType => ShortConverter @@ -154,39 +153,41 @@ object CatalystTypeConverters { /** Converter for arrays, sequences, and Java iterables. */ private case class ArrayConverter( - elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] { private[this] val elementConverter = getConverterForType(elementType) private[this] val isNoChange = isWholePrimitive(elementType) - override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + override def toCatalystImpl(scalaValue: Any): ArrayData = { scalaValue match { - case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) - case s: Seq[_] => s.map(elementConverter.toCatalyst) + case a: Array[_] => + new GenericArrayData(a.map(elementConverter.toCatalyst)) + case s: Seq[_] => + new GenericArrayData(s.map(elementConverter.toCatalyst).toArray) case i: JavaIterable[_] => val iter = i.iterator - var convertedIterable: List[Any] = List() + val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any] while (iter.hasNext) { val item = iter.next() - convertedIterable :+= elementConverter.toCatalyst(item) + convertedIterable += elementConverter.toCatalyst(item) } - convertedIterable + new GenericArrayData(convertedIterable.toArray) } } - override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + override def toScala(catalystValue: ArrayData): Seq[Any] = { if (catalystValue == null) { null } else if (isNoChange) { - catalystValue + catalystValue.toArray() } else { - catalystValue.map(elementConverter.toScala) + catalystValue.toArray().map(elementConverter.toScala) } } override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = - toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]]) + toScala(row.getArray(column)) } private case class MapConverter( @@ -305,7 +306,8 @@ object CatalystTypeConverters { DateTimeUtils.toJavaTimestamp(row.getLong(column)) } - private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + private class DecimalConverter(dataType: DecimalType) + extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) @@ -313,9 +315,11 @@ object CatalystTypeConverters { } override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal 
override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = - row.getDecimal(column).toJavaBigDecimal + row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal } + private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT) + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { final override def toScala(catalystValue: Any): Any = catalystValue final override def toCatalystImpl(scalaValue: T): Any = scalaValue @@ -402,9 +406,9 @@ object CatalystTypeConverters { case t: Timestamp => TimestampConverter.toCatalyst(t) case d: BigDecimal => BigDecimalConverter.toCatalyst(d) case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) - case seq: Seq[Any] => seq.map(convertToCatalyst) + case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) - case arr: Array[Any] => arr.map(convertToCatalyst) + case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst)) case m: Map[_, _] => m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 9a11de3840ce2..b19bf4386b0ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{Interval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. 
*/ -abstract class InternalRow extends Serializable { +abstract class InternalRow extends Serializable with SpecializedGetters { def numFields: Int @@ -38,29 +38,31 @@ abstract class InternalRow extends Serializable { def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T] - def isNullAt(ordinal: Int): Boolean = get(ordinal) == null + override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null - def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) + override def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) - def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) + override def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) - def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) + override def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) - def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) + override def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) - def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) + override def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) - def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) + override def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) - def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) + override def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) - def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) + override def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) - def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) + override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) - def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + getAs[Decimal](ordinal, DecimalType(precision, scale)) - def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType) + override def getInterval(ordinal: Int): CalendarInterval = + getAs[CalendarInterval](ordinal, CalendarIntervalType) // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString @@ -71,7 +73,10 @@ abstract class InternalRow extends Serializable { * @param ordinal position to get the struct from. 
* @param numFields number of fields the struct type has */ - def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal, null) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = + getAs[InternalRow](ordinal, null) + + override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null) override def toString: String = s"[${this.mkString(",")}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index b423f0fa04f69..f2498861c9573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.Interval +import org.apache.spark.unsafe.types.CalendarInterval /** * A very simple SQL parser. Based loosely on: @@ -332,8 +332,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val numericLiteral: Parser[Literal] = ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } | sign.? ~ unsignedFloat ^^ { - // TODO(davies): some precisions may loss, we should create decimal literal - case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue()) + case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) } ) @@ -366,32 +365,32 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val millisecond: Parser[Long] = integral <~ intervalUnit("millisecond") ^^ { - case num => num.toLong * Interval.MICROS_PER_MILLI + case num => num.toLong * CalendarInterval.MICROS_PER_MILLI } protected lazy val second: Parser[Long] = integral <~ intervalUnit("second") ^^ { - case num => num.toLong * Interval.MICROS_PER_SECOND + case num => num.toLong * CalendarInterval.MICROS_PER_SECOND } protected lazy val minute: Parser[Long] = integral <~ intervalUnit("minute") ^^ { - case num => num.toLong * Interval.MICROS_PER_MINUTE + case num => num.toLong * CalendarInterval.MICROS_PER_MINUTE } protected lazy val hour: Parser[Long] = integral <~ intervalUnit("hour") ^^ { - case num => num.toLong * Interval.MICROS_PER_HOUR + case num => num.toLong * CalendarInterval.MICROS_PER_HOUR } protected lazy val day: Parser[Long] = integral <~ intervalUnit("day") ^^ { - case num => num.toLong * Interval.MICROS_PER_DAY + case num => num.toLong * CalendarInterval.MICROS_PER_DAY } protected lazy val week: Parser[Long] = integral <~ intervalUnit("week") ^^ { - case num => num.toLong * Interval.MICROS_PER_WEEK + case num => num.toLong * CalendarInterval.MICROS_PER_WEEK } protected lazy val intervalLiteral: Parser[Literal] = @@ -407,7 +406,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { val months = Seq(year, month).map(_.getOrElse(0)).sum val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) .map(_.getOrElse(0L)).sum - Literal.create(new Interval(months, microseconds), IntervalType) + Literal.create(new CalendarInterval(months, microseconds), CalendarIntervalType) } private def toNarrowestIntegerType(value: String): Any = { @@ -420,6 +419,17 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } + private def toDecimalOrDouble(value: String): Any = { + val decimal = BigDecimal(value) + // follow the 
behavior in MS SQL Server + // https://msdn.microsoft.com/en-us/library/ms179899.aspx + if (value.contains('E') || value.contains('e')) { + decimal.doubleValue() + } else { + decimal.underlying() + } + } + protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a723e92114b32..265f3d1e41765 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.catalyst.expressions._ @@ -25,7 +27,6 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ -import scala.collection.mutable.ArrayBuffer /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -78,6 +79,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: + RemoveEvaluationFromSort :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -927,12 +929,17 @@ class Analyzer( // from LogicalPlan, currently we only do it for UnaryNode which has same output // schema with its child. case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => - val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e => - val ne = e match { - case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")() + val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr => + val leafNondeterministic = expr.collect { + case n: Nondeterministic => n + } + leafNondeterministic.map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")() + } + new TreeNodeRef(e) -> ne } - new TreeNodeRef(e) -> ne }.toMap val newPlan = p.transformExpressions { case e => nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e) @@ -941,6 +948,63 @@ class Analyzer( Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } } + + /** + * Removes all still-need-evaluate ordering expressions from sort and use an inner project to + * materialize them, finally use a outer project to project them away to keep the result same. + * Then we can make sure we only sort by [[AttributeReference]]s. 
+ * + * As an example, + * {{{ + * Sort('a, 'b + 1, + * Relation('a, 'b)) + * }}} + * will be turned into: + * {{{ + * Project('a, 'b, + * Sort('a, '_sortCondition, + * Project('a, 'b, ('b + 1).as("_sortCondition"), + * Relation('a, 'b)))) + * }}} + */ + object RemoveEvaluationFromSort extends Rule[LogicalPlan] { + private def hasAlias(expr: Expression) = { + expr.find { + case a: Alias => true + case _ => false + }.isDefined + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // The ordering expressions have no effect to the output schema of `Sort`, + // so `Alias`s in ordering expressions are unnecessary and we should remove them. + case s @ Sort(ordering, _, _) if ordering.exists(hasAlias) => + val newOrdering = ordering.map(_.transformUp { + case Alias(child, _) => child + }.asInstanceOf[SortOrder]) + s.copy(order = newOrdering) + + case s @ Sort(ordering, global, child) + if s.expressions.forall(_.resolved) && s.childrenResolved && !s.hasNoEvaluation => + + val (ref, needEval) = ordering.partition(_.child.isInstanceOf[AttributeReference]) + + val namedExpr = needEval.map(_.child match { + case n: NamedExpression => n + case e => Alias(e, "_sortCondition")() + }) + + val newOrdering = ref ++ needEval.zip(namedExpr).map { case (order, ne) => + order.copy(child = ne.toAttribute) + } + + // Add still-need-evaluate ordering expressions into inner project and then project + // them away after the sort. + Project(child.output, + Sort(newOrdering, global, + Project(child.output ++ namedExpr, child))) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a373714832962..0ebc3d180a780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -87,6 +87,18 @@ trait CheckAnalysis { s"join condition '${condition.prettyString}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") + case j @ Join(_, _, _, Some(condition)) => + def checkValidJoinConditionExprs(expr: Expression): Unit = expr match { + case p: Predicate => + p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) + case e if e.dataType.isInstanceOf[BinaryType] => + failAnalysis(s"expression ${e.prettyString} in join condition " + + s"'${condition.prettyString}' can't be binary type.") + case _ => // OK + } + + checkValidJoinConditionExprs(condition) + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK @@ -100,7 +112,15 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } + def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { + case BinaryType => + failAnalysis(s"grouping expression '${expr.prettyString}' in aggregate can " + + s"not be binary type.") + case _ => // OK + } + aggregateExprs.foreach(checkValidAggregateExpression) + aggregateExprs.foreach(checkValidGroupingExprs) case Sort(orders, _, _) => orders.foreach { order => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aa05f448d12bc..1bf7204a2515c 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -161,13 +161,6 @@ object FunctionRegistry { expression[ToDegrees]("degrees"), expression[ToRadians]("radians"), - // misc functions - expression[Md5]("md5"), - expression[Sha2]("sha2"), - expression[Sha1]("sha1"), - expression[Sha1]("sha"), - expression[Crc32]("crc32"), - // aggregate functions expression[Average]("avg"), expression[Count]("count"), @@ -212,22 +205,41 @@ object FunctionRegistry { expression[Upper]("upper"), // datetime functions + expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), + expression[DateAdd]("date_add"), expression[DateFormatClass]("date_format"), + expression[DateSub]("date_sub"), expression[DayOfMonth]("day"), expression[DayOfYear]("dayofyear"), expression[DayOfMonth]("dayofmonth"), + expression[FromUnixTime]("from_unixtime"), expression[Hour]("hour"), - expression[Month]("month"), + expression[LastDay]("last_day"), expression[Minute]("minute"), + expression[Month]("month"), + expression[MonthsBetween]("months_between"), + expression[NextDay]("next_day"), expression[Quarter]("quarter"), expression[Second]("second"), + expression[ToDate]("to_date"), + expression[TruncDate]("trunc"), + expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), // collection functions - expression[Size]("size") + expression[Size]("size"), + + // misc functions + expression[Crc32]("crc32"), + expression[Md5]("md5"), + expression[Sha1]("sha"), + expression[Sha1]("sha1"), + expression[Sha2]("sha2"), + expression[SparkPartitionID]("spark_partition_id"), + expression[InputFileName]("input_file_name") ) val builtin: FunctionRegistry = { @@ -237,7 +249,7 @@ object FunctionRegistry { } /** See usage above. */ - private def expression[T <: Expression](name: String) + def expression[T <: Expression](name: String) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { // See if we can find a constructor that accepts Seq[Expression] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e0527503442f0..603afc4032a37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -47,6 +47,7 @@ object HiveTypeCoercion { Division :: PropagateTypes :: ImplicitTypeCasts :: + DateTimeOperations :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -109,13 +110,35 @@ object HiveTypeCoercion { * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. 
*/ - private def findTightestCommonType(types: Seq[DataType]) = { + private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case None => None case Some(d) => findTightestCommonTypeOfTwo(d, c) }) } + private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (t: FractionalType, d: DecimalType) => + Some(DoubleType) + case (d: DecimalType, t: FractionalType) => + Some(DoubleType) + case _ => + findTightestCommonTypeToString(t1, t2) + } + + private def findWiderCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findWiderTypeForTwo(d, c) + case None => None + }) + } + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -182,20 +205,7 @@ object HiveTypeCoercion { val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => - (lhs.dataType, rhs.dataType) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (t: FractionalType, d: DecimalType) => - Some(DoubleType) - case (d: DecimalType, t: FractionalType) => - Some(DoubleType) - case _ => - findTightestCommonTypeToString(lhs.dataType, rhs.dataType) - } + findWiderTypeForTwo(lhs.dataType, rhs.dataType) case other => None } @@ -236,8 +246,13 @@ object HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), r) => - a.makeCopy(Array(Cast(left, DoubleType), r)) + case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) => + a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right)) + case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) => + a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT))) + + case a @ BinaryArithmetic(left @ StringType(), right) => + a.makeCopy(Array(Cast(left, DoubleType), right)) case a @ BinaryArithmetic(left, right @ StringType()) => a.makeCopy(Array(left, Cast(right, DoubleType))) @@ -543,7 +558,7 @@ object HiveTypeCoercion { // compatible with every child column. 
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -624,6 +639,27 @@ object HiveTypeCoercion { } } + /** + * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType + * to TimeAdd/TimeSub + */ + object DateTimeOperations extends Rule[LogicalPlan] { + + private val acceptedTypes = Seq(DateType, TimestampType, StringType) + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case Add(l @ CalendarIntervalType(), r) if acceptedTypes.contains(r.dataType) => + Cast(TimeAdd(r, l), r.dataType) + case Add(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) => + Cast(TimeAdd(l, r), l.dataType) + case Subtract(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) => + Cast(TimeSub(l, r), l.dataType) + } + } + /** * Casts types according to the expected input types for [[Expression]]s. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 41a877f214e55..45709c1c8f554 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -48,9 +48,9 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case DoubleType => input.getDouble(ordinal) case StringType => input.getUTF8String(ordinal) case BinaryType => input.getBinary(ordinal) - case IntervalType => input.getInterval(ordinal) + case CalendarIntervalType => input.getInterval(ordinal) case t: StructType => input.getStruct(ordinal, t.size) - case dataType => input.get(ordinal, dataType) + case _ => input.get(ordinal, dataType) } } } @@ -64,10 +64,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val value = ctx.getValue("i", dataType, ordinal.toString) s""" - boolean ${ev.isNull} = i.isNullAt($ordinal); - ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)}); + boolean ${ev.isNull} = i.isNullAt($ordinal); + $javaType ${ev.primitive} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : ($value); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bd8b0177eb00e..43be11c48ae7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{Interval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import scala.collection.mutable @@ -55,7 +55,7 @@ object Cast { case (_, DateType) => true - case (StringType, IntervalType) => true + case (StringType, CalendarIntervalType) => true case (StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true @@ -225,7 +225,7 @@ case class Cast(child: Expression, dataType: DataType) // IntervalConverter private[this] def castToInterval(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => Interval.fromString(s.toString)) + buildCast[UTF8String](_, s => CalendarInterval.fromString(s.toString)) case _ => _ => null } @@ -363,7 +363,21 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = { val elementCast = cast(from.elementType, to.elementType) - buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v))) + // TODO: Could be faster? + buildCast[ArrayData](_, array => { + val length = array.numElements() + val values = new Array[Any](length) + var i = 0 + while (i < length) { + if (array.isNullAt(i)) { + values(i) = null + } else { + values(i) = elementCast(array.get(i)) + } + i += 1 + } + new GenericArrayData(values) + }) } private[this] def castMap(from: MapType, to: MapType): Any => Any = { @@ -398,7 +412,7 @@ case class Cast(child: Expression, dataType: DataType) case DateType => castToDate(from) case decimal: DecimalType => castToDecimal(from, decimal) case TimestampType => castToTimestamp(from) - case IntervalType => castToInterval(from) + case CalendarIntervalType => castToInterval(from) case BooleanType => castToBoolean(from) case ByteType => castToByte(from) case ShortType => castToShort(from) @@ -438,7 +452,7 @@ case class Cast(child: Expression, dataType: DataType) case DateType => castToDateCode(from, ctx) case decimal: DecimalType => castToDecimalCode(from, decimal) case TimestampType => castToTimestampCode(from, ctx) - case IntervalType => castToIntervalCode(from) + case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) case ByteType => castToByteCode(from) case ShortType => castToShortCode(from) @@ -599,7 +613,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 
1L : 0L;" case _: IntegralType => (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" case DateType => @@ -630,7 +644,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"$evPrim = Interval.fromString($c.toString());" + s"$evPrim = CalendarInterval.fromString($c.toString());" } private[this] def decimalToTimestampCode(d: String): String = @@ -665,7 +679,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -687,7 +701,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -731,7 +745,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -753,7 +767,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -775,7 +789,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 
1.0d : 0.0d;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -789,37 +803,36 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castArrayCode( from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = { val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx) - - val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val arrayClass = classOf[GenericArrayData].getName val fromElementNull = ctx.freshName("feNull") val fromElementPrim = ctx.freshName("fePrim") val toElementNull = ctx.freshName("teNull") val toElementPrim = ctx.freshName("tePrim") val size = ctx.freshName("n") val j = ctx.freshName("j") - val result = ctx.freshName("result") + val values = ctx.freshName("values") (c, evPrim, evNull) => s""" - final int $size = $c.size(); - final $arraySeqClass $result = new $arraySeqClass($size); + final int $size = $c.numElements(); + final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { - if ($c.apply($j) == null) { - $result.update($j, null); + if ($c.isNullAt($j)) { + $values[$j] = null; } else { boolean $fromElementNull = false; ${ctx.javaType(from.elementType)} $fromElementPrim = - (${ctx.boxedType(from.elementType)}) $c.apply($j); + ${ctx.getValue(c, from.elementType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)} if ($toElementNull) { - $result.update($j, null); + $values[$j] = null; } else { - $result.update($j, $toElementPrim); + $values[$j] = $toElementPrim; } } } - $evPrim = $result; + $evPrim = new $arrayClass($values); """ } @@ -891,7 +904,7 @@ case class Cast(child: Expression, dataType: DataType) $result.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)}; + ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index cb4c3f24b2721..8fc182607ce68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -201,11 +201,9 @@ trait Nondeterministic extends Expression { private[this] var initialized = false - final def initialize(): Unit = { - if (!initialized) { - initInternal() - initialized = true - } + final def setInitialValues(): Unit = { + initInternal() + initialized = true } protected def initInternal(): Unit @@ -355,9 +353,9 @@ abstract class BinaryExpression extends Expression { * @param f accepts two variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s"${ev.primitive} = ${f(eval1, eval2)};" }) @@ -372,9 +370,9 @@ abstract class BinaryExpression extends Expression { * and returns Java code to compute the output. 
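 * For example (an illustrative caller, mirroring how Add uses the helper elsewhere in this
 * patch):
 *   defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 + $eval2")
 * delegates to this method, which then wraps the produced snippet in the null checks for
 * both child expressions.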
*/ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val resultCode = f(eval1.primitive, eval2.primitive) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala new file mode 100644 index 0000000000000..1e74f716955e3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]] + */ +case class InputFileName() extends LeafExpression with Nondeterministic { + + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override val prettyName = "INPUT_FILE_NAME" + + override protected def initInternal(): Unit = {} + + override protected def evalInternal(input: InternalRow): UTF8String = { + SqlNewHadoopRDD.getInputFileName() + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ev.isNull = "false" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = " + + "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index eca36b3274420..291b7a5bc3af5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -15,11 +15,10 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.expressions +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{LongType, DataType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 27d6ff587ab71..7c7664e4c1a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.types.{Decimal, StructType, DataType} import org.apache.spark.unsafe.types.UTF8String /** @@ -32,7 +32,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { this(expressions.map(BindReferences.bindReference(_, inputSchema))) expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => }) @@ -63,7 +63,7 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu this(expressions.map(BindReferences.bindReference(_, inputSchema))) expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => }) @@ -225,6 +225,11 @@ class JoinedRow extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + if (i < row1.numFields) row1.getDecimal(i, precision, scale) + else row2.getDecimal(i - row1.numFields, precision, scale) + } + override def getStruct(i: Int, numFields: Int): InternalRow = { if (i < row1.numFields) { row1.getStruct(i, numFields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3f436c0eb893c..9fe877f10fa08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator abstract sealed class SortDirection case object Ascending extends SortDirection @@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection) override def nullable: Boolean = child.nullable override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + + def isAscending: Boolean = direction == Ascending +} + +/** + * An 
expression to generate a 64-bit long prefix used in sorting. + */ +case class SortPrefix(child: SortOrder) extends UnaryExpression { + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val childCode = child.child.gen(ctx) + val input = childCode.primitive + val DoublePrefixCmp = classOf[DoublePrefixComparator].getName + + val (nullValue: Long, prefixCode: String) = child.child.dataType match { + case BooleanType => + (Long.MinValue, s"$input ? 1L : 0L") + case _: IntegralType => + (Long.MinValue, s"(long) $input") + case FloatType | DoubleType => + (DoublePrefixComparator.computePrefix(Double.NegativeInfinity), + s"$DoublePrefixCmp.computePrefix((double)$input)") + case StringType => (0L, s"$input.getPrefix()") + case _ => (0L, "0L") + } + + childCode.code + + s""" + |long ${ev.primitive} = ${nullValue}L; + |boolean ${ev.isNull} = false; + |if (!${childCode.isNull}) { + | ${ev.primitive} = $prefixCode; + |} + """.stripMargin + } + + override def dataType: DataType = LongType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 61ef079d89af5..4b1772a2deed5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -15,11 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.expressions +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{IntegerType, DataType} @@ -27,7 +26,7 @@ import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. 
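 * With this patch the expression is also registered in FunctionRegistry as
 * "spark_partition_id" (see the registry change above), and that registry builds expressions
 * through a constructor lookup, which appears to be why the case object becomes a case class
 * immediately below (an inferred rationale, shown for illustration):
 *   expression[SparkPartitionID]("spark_partition_id")  // registration added in this patch
 *   SparkPartitionID()                                   // instance the builder can construct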
*/ -private[sql] case object SparkPartitionID extends LeafExpression with Nondeterministic { +private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false @@ -35,6 +34,8 @@ private[sql] case object SparkPartitionID extends LeafExpression with Nondetermi @transient private[this] var partitionId: Int = _ + override val prettyName = "SPARK_PARTITION_ID" + override protected def initInternal(): Unit = { partitionId = TaskContext.getPartitionId() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 10bd19c8a840f..d08f553cefe8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -42,7 +42,7 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers * containing intermediate results for this function and then generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. @@ -50,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode private[sql] case object Final extends AggregateMode /** - * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly + * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly * from original input rows without any partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. @@ -103,9 +103,30 @@ abstract class AggregateFunction2 final override def foldable: Boolean = false /** - * The offset of this function's buffer in the underlying buffer shared with other functions. + * The offset of this function's start buffer value in the + * underlying shared mutable aggregation buffer. + * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share + * the same aggregation buffer. In this shared buffer, the position of the first + * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)` + * will be 2. + */ + var mutableBufferOffset: Int = 0 + + /** + * The offset of this function's start buffer value in the + * underlying shared input aggregation buffer. An input aggregation buffer is used + * when we merge two aggregation buffers and it is basically the immutable one + * (we merge an input aggregation buffer and a mutable aggregation buffer and + * then store the new buffer values to the mutable aggregation buffer). + * Usually, an input aggregation buffer also contain extra elements like grouping + * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often + * different. + * For example, we have a grouping expression `key``, and two aggregate functions + * `avg(x)` and `avg(y)`. 
In this shared input aggregation buffer, the position of the first + * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` + * will be 3 (position 0 is used for the value of key`). */ - var bufferOffset: Int = 0 + var inputBufferOffset: Int = 0 /** The schema of the aggregation buffer. */ def bufferSchema: StructType @@ -176,7 +197,7 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w override def initialize(buffer: MutableRow): Unit = { var i = 0 while (i < bufferAttributes.size) { - buffer(i + bufferOffset) = initialValues(i).eval() + buffer(i + mutableBufferOffset) = initialValues(i).eval() i += 1 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala new file mode 100644 index 0000000000000..4a43318a95490 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.types.{StructType, MapType, ArrayType} + +/** + * Utility functions used by the query planner to convert our plan to new aggregation code path. + */ +object Utils { + // Right now, we do not support complex types in the grouping key schema. + private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { + val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { + case array: ArrayType => true + case map: MapType => true + case struct: StructType => true + case _ => false + } + + !hasComplexTypes + } + + private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate if supportsGroupingKeySchema(p) => + val converted = p.transformExpressionsDown { + case expressions.Average(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Average(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Count(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(child), + mode = aggregate.Complete, + isDistinct = false) + + // We do not support multiple COUNT DISTINCT columns for now. 
+ case expressions.CountDistinct(children) if children.length == 1 => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(children.head), + mode = aggregate.Complete, + isDistinct = true) + + case expressions.First(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.First(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Last(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Last(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Max(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Max(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Min(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Min(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Sum(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.SumDistinct(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = true) + } + // Check if there is any expressions.AggregateExpression1 left. + // If so, we cannot convert this plan. + val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => + // For every expressions, check if it contains AggregateExpression1. + expr.find { + case agg: expressions.AggregateExpression1 => true + case other => false + }.isDefined + } + + // Check if there are multiple distinct columns. + val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) + val hasMultipleDistinctColumnSets = + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + true + } else { + false + } + + if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None + + case other => None + } + + def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { + // If the plan cannot be converted, we will do a final round check to see if the original + // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, + // we need to throw an exception. + val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg.aggregateFunction + } + }.distinct + if (aggregateFunction2s.nonEmpty) { + // For functions implemented based on the new interface, prepare a list of function names. + val invalidFunctions = { + if (aggregateFunction2s.length > 1) { + s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + + s"and ${aggregateFunction2s.head.nodeName} are" + } else { + s"${aggregateFunction2s.head.nodeName} is" + } + } + val errorMessage = + s"${invalidFunctions} implemented based on the new Aggregate Function " + + s"interface and it cannot be used with functions implemented based on " + + s"the old Aggregate Function interface." 
+ throw new AnalysisException(errorMessage) + } + } + + def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate => + val converted = doConvert(p) + if (converted.isDefined) { + converted + } else { + checkInvalidAggregateFunction2(p) + None + } + case other => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 42343d4d8d79c..5d4b349b1597a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -404,7 +404,7 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg // partialSum already increase the precision by 10 val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) - val castedCount = Sum(partialCount.toAttribute) + val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) SplitEvaluation( Cast(Divide(castedSum, castedCount), dataType), partialCount :: partialSum :: Nil) @@ -490,13 +490,13 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 case DecimalType.Fixed(_, _) => val partialSum = Alias(Sum(child), "PartialSum")() SplitEvaluation( - Cast(CombineSum(partialSum.toAttribute), dataType), + Cast(Sum(partialSum.toAttribute), dataType), partialSum :: Nil) case _ => val partialSum = Alias(Sum(child), "PartialSum")() SplitEvaluation( - CombineSum(partialSum.toAttribute), + Sum(partialSum.toAttribute), partialSum :: Nil) } } @@ -522,8 +522,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg private val sum = MutableLiteral(null, calcType) - private val addFunction = - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) override def update(input: InternalRow): Unit = { sum.update(addFunction, input) @@ -538,67 +537,6 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg } } -/** - * Sum should satisfy 3 cases: - * 1) sum of all null values = zero - * 2) sum for table column with no data = null - * 3) sum of column with null and not null values = sum of not null values - * Require separate CombineSum Expression and function as it has to distinguish "No data" case - * versus "data equals null" case, while aggregating results and at each partial expression.i.e., - * Combining PartitionLevel InputData - * <-- null - * Zero <-- Zero <-- null - * - * <-- null <-- no data - * null <-- null <-- no data - */ -case class CombineSum(child: Expression) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"CombineSum($child)" - override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) -} - -case class CombineSumFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. 
- - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private val sum = MutableLiteral(null, calcType) - - private val addFunction = - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - - override def update(input: InternalRow): Unit = { - val result = expr.eval(input) - // partial sum result can be null only when no input rows present - if(result != null) { - sum.update(addFunction, input) - } - } - - override def eval(input: InternalRow): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(sum, dataType).eval(null) - case _ => sum.eval(null) - } - } -} - case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { def this() = this(null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b37f530ec6814..6f8f4dd230f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.Interval +import org.apache.spark.unsafe.types.CalendarInterval case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -37,12 +37,12 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") - case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") + case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } protected override def nullSafeEval(input: Any): Any = { - if (dataType.isInstanceOf[IntervalType]) { - input.asInstanceOf[Interval].negate() + if (dataType.isInstanceOf[CalendarIntervalType]) { + input.asInstanceOf[CalendarInterval].negate() } else { numeric.negate(input) } @@ -68,8 +68,7 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects @ExpressionDescription( usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", extended = "> SELECT _FUNC_('-1');\n1") -case class Abs(child: Expression) - extends UnaryExpression with ExpectsInputTypes with CodegenFallback { +case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) @@ -122,8 +121,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (dataType.isInstanceOf[IntervalType]) { - input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval]) + if (dataType.isInstanceOf[CalendarIntervalType]) { + input1.asInstanceOf[CalendarInterval].add(input2.asInstanceOf[CalendarInterval]) } else { numeric.plus(input1, input2) } @@ -135,7 +134,7 @@ case class 
Add(left: Expression, right: Expression) extends BinaryArithmetic { case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") - case IntervalType => + case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") @@ -151,8 +150,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (dataType.isInstanceOf[IntervalType]) { - input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval]) + if (dataType.isInstanceOf[CalendarIntervalType]) { + input1.asInstanceOf[CalendarInterval].subtract(input2.asInstanceOf[CalendarInterval]) } else { numeric.minus(input1, input2) } @@ -164,7 +163,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") - case IntervalType => + case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 2087cc7f109bc..c98182c96b165 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen /** - * An utility class that indents a block of code based on the curly braces. - * + * An utility class that indents a block of code based on the curly braces and parentheses. * This is used to prettify generated code when in debug mode (or exceptions). * * Written by Matei Zaharia. @@ -35,11 +34,12 @@ private class CodeFormatter { private var indentString = "" private def addLine(line: String): Unit = { - val indentChange = line.count(_ == '{') - line.count(_ == '}') + val indentChange = + line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0) val newIndentLevel = math.max(0, indentLevel + indentChange) // Lines starting with '}' should be de-indented even if they contain '{' after; // in addition, lines ending with ':' are typically labels - val thisLineIndent = if (line.startsWith("}") || line.endsWith(":")) { + val thisLineIndent = if (line.startsWith("}") || line.startsWith(")") || line.endsWith(":")) { " " * (indentSize * (indentLevel - 1)) } else { indentString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2f02c90b1d5b3..60e2863f7bbb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -100,17 +100,19 @@ class CodeGenContext { } /** - * Returns the code to access a column in Row for a given DataType. + * Returns the code to access a value in `SpecializedGetters` for a given DataType. 
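 * A few illustrative expansions, derived from the cases below (example strings only):
 *   getValue("i", StringType, "0")                  // "i.getUTF8String(0)"
 *   getValue("row", DecimalType(10, 2), "ordinal")  // "row.getDecimal(ordinal, 10, 2)"
 *   getValue("arr", IntegerType, "j")               // "arr.getInt(j)"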
*/ - def getColumn(row: String, dataType: DataType, ordinal: Int): String = { + def getValue(getter: String, dataType: DataType, ordinal: String): String = { val jt = javaType(dataType) dataType match { - case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" - case StringType => s"$row.getUTF8String($ordinal)" - case BinaryType => s"$row.getBinary($ordinal)" - case IntervalType => s"$row.getInterval($ordinal)" - case t: StructType => s"$row.getStruct($ordinal, ${t.size})" - case _ => s"($jt)$row.get($ordinal)" + case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case StringType => s"$getter.getUTF8String($ordinal)" + case BinaryType => s"$getter.getBinary($ordinal)" + case CalendarIntervalType => s"$getter.getInterval($ordinal)" + case t: StructType => s"$getter.getStruct($ordinal, ${t.size})" + case a: ArrayType => s"$getter.getArray($ordinal)" + case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter. } } @@ -119,10 +121,10 @@ class CodeGenContext { */ def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { val jt = javaType(dataType) - if (isPrimitiveType(jt)) { - s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - } else { - s"$row.update($ordinal, $value)" + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case _ => s"$row.update($ordinal, $value)" } } @@ -150,10 +152,10 @@ class CodeGenContext { case dt: DecimalType => "Decimal" case BinaryType => "byte[]" case StringType => "UTF8String" - case IntervalType => "Interval" + case CalendarIntervalType => "CalendarInterval" case _: StructType => "InternalRow" - case _: ArrayType => s"scala.collection.Seq" - case _: MapType => s"scala.collection.Map" + case _: ArrayType => "ArrayData" + case _: MapType => "scala.collection.Map" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" @@ -214,7 +216,9 @@ class CodeGenContext { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case NullType => "0" - case other => s"$c1.compare($c2)" + case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" + case _ => throw new IllegalArgumentException( + "cannot generate compare code for un-comparable type") } /** @@ -293,7 +297,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[UnsafeRow].getName, classOf[UTF8String].getName, classOf[Decimal].getName, - classOf[Interval].getName + classOf[CalendarInterval].getName, + classOf[ArrayData].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 6b187f05604fd..3492d2c6189ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression} /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.Expression trait CodegenFallback extends Expression { protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + foreach { + case n: Nondeterministic => n.setInitialValues() + case _ => + } + ctx.references += this val objectTerm = ctx.freshName("obj") s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 9d2161947b351..1d223986d9441 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -34,15 +34,74 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName + private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName + private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName + private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case t: AtomicType if !t.isInstanceOf[DecimalType] => true - case _: IntervalType => true + case _: CalendarIntervalType => true + case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case NullType => true + case t: DecimalType => true case _ => false } + def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))" + case StringType => + s" + (${ev.isNull} ? 
0 : $StringWriter.getSize(${ev.primitive}))" + case BinaryType => + s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" + case CalendarIntervalType => + s" + (${ev.isNull} ? 0 : 16)" + case _: StructType => + s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + case _ => "" + } + + def genFieldWriter( + ctx: CodeGenContext, + fieldType: DataType, + ev: GeneratedExpressionCode, + primitive: String, + index: Int, + cursor: String): String = fieldType match { + case _ if ctx.isPrimitiveType(fieldType) => + s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}" + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case StringType => + s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case BinaryType => + s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case CalendarIntervalType => + s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case t: StructType => + s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") + } + /** * Generates the code to create an [[UnsafeRow]] object based on the input expressions. * @param ctx context for code generation @@ -55,41 +114,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ret = ev.primitive ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();") - val bufferTerm = ctx.freshName("buffer") - ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];") - val cursorTerm = ctx.freshName("cursor") - val numBytesTerm = ctx.freshName("numBytes") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val cursor = ctx.freshName("cursor") + val numBytes = ctx.freshName("numBytes") - val exprs = expressions.map(_.gen(ctx)) + val exprs = expressions.map { e => e.dataType match { + case st: StructType => createCodeForStruct(ctx, e.gen(ctx), st) + case _ => e.gen(ctx) + }} val allExprs = exprs.map(_.code).mkString("\n") - val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = expressions.zipWithIndex.map { case (e, i) => - e.dataType match { - case StringType => - s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" - case BinaryType => - s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" - case IntervalType => - s" + (${exprs(i).isNull} ? 
0 : 16)" - case _ => "" - } + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val additionalSize = expressions.zipWithIndex.map { + case (e, i) => genAdditionalSize(e.dataType, exprs(i)) }.mkString("") val writers = expressions.zipWithIndex.map { case (e, i) => - val update = e.dataType match { - case dt if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}" - case StringType => - s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" - case BinaryType => - s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" - case IntervalType => - s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") - } + val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor) s"""if (${exprs(i).isNull}) { $ret.setNullAt($i); } else { @@ -99,24 +141,115 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $allExprs - int $numBytesTerm = $fixedSize $additionalSize; - if ($numBytesTerm > $bufferTerm.length) { - $bufferTerm = new byte[$numBytesTerm]; + int $numBytes = $fixedSize $additionalSize; + if ($numBytes > $buffer.length) { + $buffer = new byte[$numBytes]; } $ret.pointTo( - $bufferTerm, + $buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, ${expressions.size}, - $numBytesTerm); - int $cursorTerm = $fixedSize; - + $numBytes); + int $cursor = $fixedSize; $writers boolean ${ev.isNull} = false; """ } + /** + * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. + * + * This function also handles nested structs by recursively generating the code to do conversion. + * + * @param ctx code generation context + * @param input the input struct, identified by a [[GeneratedExpressionCode]] + * @param schema schema of the struct field + */ + // TODO: refactor createCode and this function to reduce code duplication. 
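  // Sketch of the generated conversion for a nested field such as
  // StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))), shown for
  // illustration rather than emitted verbatim: compute numBytes as fixedSize plus
  // StringWriter.getSize(b), grow the shared byte[] buffer if needed, write "a" in place via
  // setInt and append "b" at the moving cursor, following the same pattern as createCode
  // above, applied recursively for each struct level.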
+ private def createCodeForStruct( + ctx: CodeGenContext, + input: GeneratedExpressionCode, + schema: StructType): GeneratedExpressionCode = { + + val isNull = input.isNull + val primitive = ctx.freshName("structConvert") + ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val cursor = ctx.freshName("cursor") + + val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map { + case (dt, i) => dt match { + case st: StructType => + val nestedStructEv = GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" + ) + createCodeForStruct(ctx, nestedStructEv, st) + case _ => + GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" + ) + } + } + val allExprs = exprs.map(_.code).mkString("\n") + + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => + genAdditionalSize(dt, ev) + }.mkString("") + + val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => + val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor) + s""" + if (${exprs(i).isNull}) { + $primitive.setNullAt($i); + } else { + $update; + } + """ + }.mkString("\n ") + + // Note that we add a shortcut here for performance: if the input is already an UnsafeRow, + // just copy the bytes directly into our buffer space without running any conversion. + // We also had to use a hack to introduce a "tmp" variable, to avoid the Java compiler from + // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow. 
+ val tmp = ctx.freshName("tmp") + val numBytes = ctx.freshName("numBytes") + val code = s""" + |${input.code} + |if (!${input.isNull}) { + | Object $tmp = (Object) ${input.primitive}; + | if ($tmp instanceof UnsafeRow) { + | $primitive = (UnsafeRow) $tmp; + | } else { + | $allExprs + | + | int $numBytes = $fixedSize $additionalSize; + | if ($numBytes > $buffer.length) { + | $buffer = new byte[$numBytes]; + | } + | + | $primitive.pointTo( + | $buffer, + | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + | ${exprs.size}, + | $numBytes); + | int $cursor = $fixedSize; + | + | $writers + | } + |} + """.stripMargin + + GeneratedExpressionCode(code, isNull, primitive) + } + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -132,18 +265,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro eval.code = createCode(ctx, eval, expressions) val code = s""" - private $exprType[] expressions; - - public Object generate($exprType[] expr) { - this.expressions = expr; - return new SpecificProjection(); + public Object generate($exprType[] exprs) { + return new SpecificProjection(exprs); } class SpecificProjection extends ${classOf[UnsafeProjection].getName} { + private $exprType[] expressions; + ${declareMutableStates(ctx)} - public SpecificProjection() { + public SpecificProjection($exprType[] expressions) { + this.expressions = expressions; ${initMutableStates(ctx)} } @@ -159,7 +292,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } """ - logDebug(s"code for ${expressions.mkString(",")}:\n$code") + logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2d92dcf23a86e..1a00dbc254de1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,11 +27,15 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) override def nullSafeEval(value: Any): Int = child.dataType match { - case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size - case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size + case _: ArrayType => value.asInstanceOf[ArrayData].numElements() + case _: MapType => value.asInstanceOf[Map[Any, Any]].size } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();") + val sizeCall = child.dataType match { + case _: ArrayType => "numElements()" + case _: MapType => "size()" + } + nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 119168fa59f15..a145dfb4bbf08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.mutable - +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -44,25 +43,26 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false override def eval(input: InternalRow): Any = { - children.map(_.eval(input)) + new GenericArrayData(children.map(_.eval(input)).toArray) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val arrayClass = classOf[GenericArrayData].getName s""" - boolean ${ev.isNull} = false; - $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size}); + final boolean ${ev.isNull} = false; + final Object[] values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);" } override def prettyName: String = "array" @@ -104,18 +104,19 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ }.mkString("\n") } override def prettyName: String = "struct" } + /** * Creates a struct with the given field names and values * @@ -126,11 +127,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + private lazy val names = nameExprs.map(_.eval(EmptyRow)) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } @@ -143,14 +145,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { - val invalidNames = - nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) + val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Odd position only allow foldable and not-null StringType expressions, got :" + + s"Only foldable StringType expressions 
are allowed to appear at odd position , got :" + s" ${invalidNames.mkString(",")}") - } else { + } else if (names.forall(_ != null)){ TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Field name should not be null") } } } @@ -168,14 +171,83 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { valExprs.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ }.mkString("\n") } override def prettyName: String = "named_struct" } + +/** + * Returns a Row containing the evaluation of all children expressions. This is a variant that + * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. + */ +case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + GenerateUnsafeProjection.createCode(ctx, ev, children) + } + + override def prettyName: String = "struct_unsafe" +} + + +/** + * Creates a struct with the given field names and values. This is a variant that returns + * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. + * + * @param children Seq(name1, val1, name2, val2, ...) 
+ */ +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + GenerateUnsafeProjection.createCode(ctx, ev, valExprs) + } + + override def prettyName: String = "named_struct_unsafe" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6331a9eb603ca..99393c9c76ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -57,7 +57,8 @@ object ExtractValue { case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) + GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), + ordinal, fields.length, containsNull) case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) @@ -118,7 +119,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)}; + ${ev.primitive} = ${ctx.getValue(eval, dataType, ordinal.toString)}; } """ }) @@ -134,6 +135,7 @@ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, + numFields: Int, containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) @@ -141,26 +143,45 @@ case class GetArrayStructFields( override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { - input.asInstanceOf[Seq[InternalRow]].map { row => - if (row == null) null else row.get(ordinal, field.dataType) + val array = input.asInstanceOf[ArrayData] + val length = array.numElements() + val result = new Array[Any](length) + var i = 0 + while (i < length) { + if (array.isNullAt(i)) { + result(i) = null + } else { + val row = array.getStruct(i, numFields) + if (row.isNullAt(ordinal)) { + result(i) = null + } else { + result(i) = row.get(ordinal, field.dataType) + } + } + i += 1 } + new GenericArrayData(result) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arraySeqClass = "scala.collection.mutable.ArraySeq" - // TODO: consider using Array[_] for ArrayType child to avoid - // boxing of primitives + val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { s""" - final int n = $eval.size(); - final $arraySeqClass values = new $arraySeqClass(n); + final int n = 
$eval.numElements(); + final Object[] values = new Object[n]; for (int j = 0; j < n; j++) { - InternalRow row = (InternalRow) $eval.apply(j); - if (row != null && !row.isNullAt($ordinal)) { - values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); + if ($eval.isNullAt(j)) { + values[j] = null; + } else { + final InternalRow row = $eval.getStruct(j, $numFields); + if (row.isNullAt($ordinal)) { + values[j] = null; + } else { + values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)}; + } } } - ${ev.primitive} = (${ctx.javaType(dataType)}) values; + ${ev.primitive} = new $arrayClass(values); """ }) } @@ -186,23 +207,23 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx protected override def nullSafeEval(value: Any, ordinal: Any): Any = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives - val baseValue = value.asInstanceOf[Seq[_]] + val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.size || index < 0) { + if (index >= baseValue.numElements() || index < 0) { null } else { - baseValue(index) + baseValue.get(index) } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - final int index = (int)$eval2; - if (index >= $eval1.size() || index < 0) { + final int index = (int) $eval2; + if (index >= $eval1.numElements() || index < 0) { ${ev.isNull} = true; } else { - ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index); + ${ev.primitive} = ${ctx.getValue(eval1, dataType, "index")}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 15b33da884dcb..961b1d8616801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -315,7 +315,6 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW * It takes at least 2 parameters, and returns null iff all parameters are null. */ case class Least(children: Seq[Expression]) extends Expression { - require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -323,7 +322,9 @@ case class Least(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST (${children.map(_.dataType)}).") @@ -369,7 +370,6 @@ case class Least(children: Seq[Expression]) extends Expression { * It takes at least 2 parameters, and returns null iff all parameters are null. 
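// Illustrative sketch, not part of the patch: the LEAST/GREATEST null handling documented
// above, restated over plain Scala values -- null arguments are skipped, and the result is
// null only when every argument is null. The helper name below is made up for the example.
def greatestIgnoringNulls(values: Seq[java.lang.Integer]): java.lang.Integer = {
  val nonNull = values.collect { case v if v != null => v.intValue }
  if (nonNull.isEmpty) null else Int.box(nonNull.max)
}
// greatestIgnoringNulls(Seq[java.lang.Integer](1, null, 3))  == 3
// greatestIgnoringNulls(Seq[java.lang.Integer](null, null))  == null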
*/ case class Greatest(children: Seq[Expression]) extends Expression { - require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -377,7 +377,9 @@ case class Greatest(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST (${children.map(_.dataType)}).") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 9e55f0546e123..6e7613340c032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Date import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} @@ -26,7 +25,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import scala.util.Try /** * Returns the current date at the start of query evaluation. @@ -62,6 +63,53 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { } } +/** + * Adds a number of days to startdate. + */ +case class DateAdd(startDate: Expression, days: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = days + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, d: Any): Any = { + start.asInstanceOf[Int] + d.asInstanceOf[Int] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, d) => { + s"""${ev.primitive} = $sd + $d;""" + }) + } +} + +/** + * Subtracts a number of days to startdate. 
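// Illustrative sketch, not part of the patch: DateType values are stored as days since
// 1970-01-01, so DATE_ADD/DATE_SUB reduce to plain Int arithmetic, matching the
// nullSafeEval of DateAdd above and DateSub below. java.time is used only to make the
// example self-checking.
import java.time.LocalDate

val start = LocalDate.of(2015, 7, 27).toEpochDay.toInt   // days since the epoch
assert(LocalDate.ofEpochDay((start + 10).toLong) == LocalDate.of(2015, 8, 6))   // date_add
assert(LocalDate.ofEpochDay((start - 10).toLong) == LocalDate.of(2015, 7, 17))  // date_sub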
+ */ +case class DateSub(startDate: Expression, days: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + override def left: Expression = startDate + override def right: Expression = days + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, d: Any): Any = { + start.asInstanceOf[Int] - d.asInstanceOf[Int] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, d) => { + s"""${ev.primitive} = $sd - $d;""" + }) + } +} + case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -74,9 +122,7 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getHours($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") } } @@ -92,9 +138,7 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getMinutes($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") } } @@ -110,9 +154,7 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getSeconds($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") } } @@ -128,9 +170,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getDayInYear($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") } } @@ -147,9 +187,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => - s"""$dtu.getYear($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") } } @@ -165,9 +203,7 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getQuarter($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") } } @@ -183,9 +219,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getMonth($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") } } @@ -201,9 +235,7 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getDayOfMonth($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") } } @@ -226,7 +258,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (time) => { + nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") ctx.addMutableState(cal, c, @@ -250,18 +282,503 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) - override def prettyName: String = "date_format" - override protected def nullSafeEval(timestamp: Any, format: Any): Any = { val sdf = new SimpleDateFormat(format.toString) - UTF8String.fromString(sdf.format(new Date(timestamp.asInstanceOf[Long] / 1000))) + UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val sdf = classOf[SimpleDateFormat].getName defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString((new $sdf($format.toString())) - .format(new java.sql.Date($timestamp / 1000)))""" + .format(new java.util.Date($timestamp / 1000)))""" + }) + } + + override def prettyName: String = "date_format" +} + +/** + * Converts time string with given pattern + * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) + * to Unix time stamp (in seconds), returns null if fail. + * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. + * If the second parameter is missing, use "yyyy-MM-dd HH:mm:ss". + * If no parameters provided, the first parameter will be current_timestamp. + * If the first parameter is a Date or Timestamp instead of String, we will ignore the + * second parameter. 
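// Illustrative sketch, not part of the patch: the string branch of unix_timestamp described
// above boils down to SimpleDateFormat parsing in the JVM time zone, with any parse failure
// mapped to null. The helper name is made up; None stands in for the SQL NULL.
import java.text.SimpleDateFormat
import scala.util.Try

def unixTimestampOf(s: String, fmt: String = "yyyy-MM-dd HH:mm:ss"): Option[Long] =
  Try(new SimpleDateFormat(fmt).parse(s).getTime / 1000L).toOption

// unixTimestampOf("2015-07-24 10:00:00")  => Some(<seconds, depends on the time zone>)
// unixTimestampOf("not a timestamp")      => None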
+ */ +case class UnixTimestamp(timeExp: Expression, format: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = timeExp + override def right: Expression = format + + def this(time: Expression) = { + this(time, Literal("yyyy-MM-dd HH:mm:ss")) + } + + def this() = { + this(CurrentTimestamp()) + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringType, DateType, TimestampType), StringType) + + override def dataType: DataType = LongType + + private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + + override def eval(input: InternalRow): Any = { + val t = left.eval(input) + if (t == null) { + null + } else { + left.dataType match { + case DateType => + DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L + case TimestampType => + t.asInstanceOf[Long] / 1000000L + case StringType if right.foldable => + if (constFormat != null) { + Try(new SimpleDateFormat(constFormat.toString).parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + } else { + null + } + case StringType => + val f = format.eval(input) + if (f == null) { + null + } else { + val formatString = f.asInstanceOf[UTF8String].toString + Try(new SimpleDateFormat(formatString).parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + } + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + left.dataType match { + case StringType if right.foldable => + val sdf = classOf[SimpleDateFormat].getName + val fString = if (constFormat == null) null else constFormat.toString + val formatter = ctx.freshName("formatter") + if (fString == null) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + try { + $sdf $formatter = new $sdf("$fString"); + ${ev.primitive} = + $formatter.parse(${eval1.primitive}.toString()).getTime() / 1000L; + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + } + """ + } + case StringType => + val sdf = classOf[SimpleDateFormat].getName + nullSafeCodeGen(ctx, ev, (string, format) => { + s""" + try { + ${ev.primitive} = + (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L; + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + """ + }) + case TimestampType => + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = ${eval1.primitive} / 1000000L; + } + """ + case DateType => + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $dtu.daysToMillis(${eval1.primitive}) / 1000L; + } + """ + } + } +} + +/** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. If the format is missing, using format like "1970-01-01 00:00:00". + * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. 
+ */ +case class FromUnixTime(sec: Expression, format: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = sec + override def right: Expression = format + + def this(unix: Expression) = { + this(unix, Literal("yyyy-MM-dd HH:mm:ss")) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) + + private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + + override def eval(input: InternalRow): Any = { + val time = left.eval(input) + if (time == null) { + null + } else { + if (format.foldable) { + if (constFormat == null) { + null + } else { + Try(UTF8String.fromString(new SimpleDateFormat(constFormat.toString).format( + new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + } + } else { + val f = format.eval(input) + if (f == null) { + null + } else { + Try(UTF8String.fromString(new SimpleDateFormat( + f.asInstanceOf[UTF8String].toString).format(new java.util.Date( + time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + } + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val sdf = classOf[SimpleDateFormat].getName + if (format.foldable) { + if (constFormat == null) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val t = left.gen(ctx) + s""" + ${t.code} + boolean ${ev.isNull} = ${t.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + try { + ${ev.primitive} = UTF8String.fromString(new $sdf("${constFormat.toString}").format( + new java.util.Date(${t.primitive} * 1000L))); + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + } + """ + } + } else { + nullSafeCodeGen(ctx, ev, (seconds, f) => { + s""" + try { + ${ev.primitive} = UTF8String.fromString((new $sdf($f.toString())).format( + new java.util.Date($seconds * 1000L))); + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + }""".stripMargin + }) + } + } +} + +/** + * Returns the last day of the month which the date belongs to. + */ +case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def child: Expression = startDate + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = DateType + + override def nullSafeEval(date: Any): Any = { + DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") + } + + override def prettyName: String = "last_day" +} + +/** + * Returns the first date which is later than startDate and named as dayOfWeek. + * For example, NextDay(2015-07-27, Sunday) would return 2015-08-02, which is the first + * Sunday later than 2015-07-27. + * + * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]]. 
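// Illustrative sketch, not part of the patch: next_day works on days-since-epoch with the
// day of week encoded relative to Thursday (1970-01-01 was a Thursday), mirroring the
// DateTimeUtils.getNextDateForDayOfWeek helper added further down in this patch.
import java.time.LocalDate

def nextDateForDayOfWeek(startDate: Int, dayOfWeek: Int): Int =
  startDate + 1 + ((dayOfWeek - 1 - startDate) % 7 + 7) % 7

val monday = LocalDate.of(2015, 7, 27).toEpochDay.toInt
val sunday = 3   // Thu = 0, Fri = 1, Sat = 2, Sun = 3, ...
assert(LocalDate.ofEpochDay(nextDateForDayOfWeek(monday, sunday).toLong) ==
  LocalDate.of(2015, 8, 2))   // the 2015-07-27 / Sunday example from the comment above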
+ */ +case class NextDay(startDate: Expression, dayOfWeek: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = dayOfWeek + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, dayOfW: Any): Any = { + val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String]) + if (dow == -1) { + null + } else { + val sd = start.asInstanceOf[Int] + DateTimeUtils.getNextDateForDayOfWeek(sd, dow) + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, dowS) => { + val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") + val dayOfWeekTerm = ctx.freshName("dayOfWeek") + if (dayOfWeek.foldable) { + val input = dayOfWeek.eval().asInstanceOf[UTF8String] + if ((input eq null) || DateTimeUtils.getDayOfWeekFromString(input) == -1) { + s""" + |${ev.isNull} = true; + """.stripMargin + } else { + val dayOfWeekValue = DateTimeUtils.getDayOfWeekFromString(input) + s""" + |${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekValue); + """.stripMargin + } + } else { + s""" + |int $dayOfWeekTerm = $dateTimeUtilClass.getDayOfWeekFromString($dowS); + |if ($dayOfWeekTerm == -1) { + | ${ev.isNull} = true; + |} else { + | ${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekTerm); + |} + """.stripMargin + } }) } + + override def prettyName: String = "next_day" +} + +/** + * Adds an interval to timestamp. + */ +case class TimeAdd(start: Expression, interval: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left + $right" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(start: Any, interval: Any): Any = { + val itvl = interval.asInstanceOf[CalendarInterval] + DateTimeUtils.timestampAddInterval( + start.asInstanceOf[Long], itvl.months, itvl.microseconds) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, i) => { + s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" + }) + } +} + +/** + * Subtracts an interval from timestamp. 
+ */ +case class TimeSub(start: Expression, interval: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left - $right" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(start: Any, interval: Any): Any = { + val itvl = interval.asInstanceOf[CalendarInterval] + DateTimeUtils.timestampAddInterval( + start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, i) => { + s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" + }) + } +} + +/** + * Returns the date that is num_months after start_date. + */ +case class AddMonths(startDate: Expression, numMonths: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = numMonths + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, months: Any): Any = { + DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, m) => { + s"""$dtu.dateAddMonths($sd, $m)""" + }) + } +} + +/** + * Returns number of months between dates date1 and date2. + */ +case class MonthsBetween(date1: Expression, date2: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = date1 + override def right: Expression = date2 + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + + override def dataType: DataType = DoubleType + + override def nullSafeEval(t1: Any, t2: Any): Any = { + DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (l, r) => { + s"""$dtu.monthsBetween($l, $r)""" + }) + } +} + +/** + * Returns the date part of a timestamp or string. + */ +case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Implicit casting of spark will accept string in both date and timestamp format, as + // well as TimestampType. + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = DateType + + override def eval(input: InternalRow): Any = child.eval(input) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, d => d) + } +} + +/* + * Returns date truncated to the unit specified by the format. 
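// Illustrative sketch, not part of the patch: trunc(date, fmt) keeps only the year or month
// component and returns null for any other format, e.g. trunc('2015-07-27', 'MM') is
// 2015-07-01 and trunc('2015-07-27', 'YEAR') is 2015-01-01. java.time is used here only to
// restate the semantics outside the codegen path.
import java.time.LocalDate

def truncToMonth(d: LocalDate): LocalDate = d.withDayOfMonth(1)
def truncToYear(d: LocalDate): LocalDate = d.withDayOfYear(1)

assert(truncToMonth(LocalDate.of(2015, 7, 27)) == LocalDate.of(2015, 7, 1))
assert(truncToYear(LocalDate.of(2015, 7, 27)) == LocalDate.of(2015, 1, 1))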
+ */ +case class TruncDate(date: Expression, format: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + override def left: Expression = date + override def right: Expression = format + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def dataType: DataType = DateType + override def prettyName: String = "trunc" + + lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + + override def eval(input: InternalRow): Any = { + val minItem = if (format.foldable) { + minItemConst + } else { + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + } + if (minItem == -1) { + // unknown format + null + } else { + val d = date.eval(input) + if (d == null) { + null + } else { + DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + + if (format.foldable) { + if (minItemConst == -1) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val d = date.gen(ctx) + s""" + ${d.code} + boolean ${ev.isNull} = ${d.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst); + } + """ + } + } else { + nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { + val form = ctx.freshName("form") + s""" + int $form = $dtu.parseTruncLevel($fmt); + if ($form == -1) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $dtu.truncDate($dateVal, $form); + } + """ + }) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 2dbcf2830f876..8064235c64ef9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -121,8 +121,8 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit override def eval(input: InternalRow): TraversableOnce[InternalRow] = { child.dataType match { case ArrayType(_, _) => - val inputArray = child.eval(input).asInstanceOf[Seq[Any]] - if (inputArray == null) Nil else inputArray.map(v => InternalRow(v)) + val inputArray = child.eval(input).asInstanceOf[ArrayData] + if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v)) case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] if (inputMap == null) Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 85060b7893556..34bad23802ba4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -42,7 +42,7 @@ object Literal { case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) - case i: Interval => Literal(i, IntervalType) + case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) case _ => throw new 
RuntimeException("Unsupported literal type " + v.getClass + " " + v) @@ -118,7 +118,7 @@ case class Literal protected (value: Any, dataType: DataType) super.genCode(ctx, ev) } else { ev.isNull = "false" - ev.primitive = s"${value}" + ev.primitive = s"${value}D" "" } case ByteType | ShortType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 68cca0ad3d067..e6d807f6d897b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -646,19 +646,19 @@ case class Logarithm(left: Expression, right: Expression) /** * Round the `child`'s result to `scale` decimal place when `scale` >= 0 * or round at integral part when `scale` < 0. - * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30. * - * Child of IntegralType would eval to itself when `scale` >= 0. - * Child of FractionalType whose value is NaN or Infinite would always eval to itself. + * Child of IntegralType would round to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always round to itself. * - * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], - * which leads to scale update in DecimalType's [[PrecisionInfo]] + * Round's dataType would always equal to `child`'s dataType except for DecimalType, + * which would lead scale decrease from the origin DecimalType. * * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime */ case class Round(child: Expression, scale: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { import BigDecimal.RoundingMode.HALF_UP @@ -838,6 +838,4 @@ case class Round(child: Expression, scale: Expression) """ } } - - override def prettyName: String = "round" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5bfe1cad24a3e..ab7d3afce8f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -31,7 +31,7 @@ object InterpretedPredicate { def create(expression: Expression): (InternalRow => Boolean) = { expression.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 8f30519697a37..62d3d204ca872 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -66,7 +66,7 @@ case class Rand(seed: Long) extends RDG { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className($seed + 
org.apache.spark.TaskContext.getPartitionId());") + s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); @@ -89,7 +89,7 @@ case class Randn(seed: Long) extends RDG { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());") + s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index b7c4ece4a16fe..df6ea586c87ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DataType, StructType, AtomicType} +import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType} import org.apache.spark.unsafe.types.UTF8String /** @@ -39,6 +39,7 @@ abstract class MutableRow extends InternalRow { def setShort(i: Int, value: Short): Unit = { update(i, value) } def setByte(i: Int, value: Byte): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } def setString(i: Int, value: String): Unit = { update(i, UTF8String.fromString(value)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 38b0fb37dee3b..79c0ca56a8e79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -22,7 +22,6 @@ import java.util.Locale import java.util.regex.{MatchResult, Pattern} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -52,7 +51,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + s"${eval.isNull} ? 
null : ${eval.primitive}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; @@ -93,7 +92,7 @@ case class ConcatWs(children: Seq[Expression]) val flatInputs = children.flatMap { child => child.eval(input) match { case s: UTF8String => Iterator(s) - case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]] + case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String]) case null => Iterator(null.asInstanceOf[UTF8String]) } } @@ -106,7 +105,7 @@ case class ConcatWs(children: Seq[Expression]) val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" @@ -666,13 +665,15 @@ case class StringSplit(str: Expression, pattern: Expression) override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def nullSafeEval(string: Any, regex: Any): Any = { - string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq + val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + new GenericArrayData(strings.asInstanceOf[Array[Any]]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, pattern) => - s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer( - java.util.Arrays.asList($str.split($pattern, -1)));""") + // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. + s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") } override def prettyName: String = "split" @@ -777,7 +778,6 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) override def dataType: DataType = IntegerType - protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) @@ -1009,7 +1009,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio s""" ${evalSubject.code} - boolean ${ev.isNull} = ${evalSubject.isNull}; + boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${evalSubject.isNull}) { ${evalRegexp.code} @@ -1104,9 +1104,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val evalIdx = idx.gen(ctx) s""" - ${ctx.javaType(dataType)} ${ev.primitive} = null; - boolean ${ev.isNull} = true; ${evalSubject.code} + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + boolean ${ev.isNull} = true; if (!${evalSubject.isNull}) { ${evalRegexp.code} if (!${evalRegexp.isNull}) { @@ -1118,7 +1118,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); + ${termPattern}.matcher(${evalSubject.primitive}.toString()); if (m.find()) { ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); @@ -1140,7 +1140,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio * fractional part. 
*/ case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + extends BinaryExpression with ExpectsInputTypes { override def left: Expression = x override def right: Expression = d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 813c62009666c..29d706dcb39a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -312,7 +312,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType) case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType) case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ GetArrayStructFields(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) + case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) => + Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index af68358daf5f1..a67f8de6b733a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Utils import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -33,7 +34,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend }.nonEmpty ) - !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions + expressions.forall(_.resolved) && childrenResolved && !hasSpecialExpressions } } @@ -67,7 +68,7 @@ case class Generate( generator.resolved && childrenResolved && generator.elementTypes.length == generatorOutput.length && - !generatorOutput.exists(!_.resolved) + generatorOutput.forall(_.resolved) } // we don't want the gOutput to be taken as part of the expressions @@ -187,7 +188,7 @@ case class WithWindowDefinition( } /** - * @param order The ordering expressions + * @param order The ordering expressions, should all be [[AttributeReference]] * @param global True means global sorting apply for entire data set, * False means sorting only apply within the partition. 
* @param child Child logical plan @@ -197,6 +198,11 @@ case class Sort( global: Boolean, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + def hasNoEvaluation: Boolean = order.forall(_.child.isInstanceOf[AttributeReference]) + + override lazy val resolved: Boolean = + expressions.forall(_.resolved) && childrenResolved && hasNoEvaluation } case class Aggregate( @@ -211,9 +217,11 @@ case class Aggregate( }.nonEmpty ) - !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions + expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions } + lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2dcfa19fec383..f4d1dbaf28efe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -86,14 +86,6 @@ sealed trait Partitioning { */ def satisfies(required: Distribution): Boolean - /** - * Returns true iff all distribution guarantees made by this partitioning can also be made - * for the `other` specified partitioning. - * For example, two [[HashPartitioning HashPartitioning]]s are - * only compatible if the `numPartitions` of them is the same. - */ - def compatibleWith(other: Partitioning): Boolean - /** Returns the expressions that are used to key the partitioning. */ def keyExpressions: Seq[Expression] } @@ -104,11 +96,6 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case UnknownPartitioning(_) => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -117,11 +104,6 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -130,11 +112,6 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -159,12 +136,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case h: HashPartitioning if h == this => true - case _ => false - } - override def keyExpressions: Seq[Expression] = expressions } @@ -199,11 +170,5 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case r: RangePartitioning if r == this => true - case _ => false - } - override def keyExpressions: Seq[Expression] = ordering.map(_.child) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 07412e73b6a5b..5a7c25b8d508d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -45,6 +45,7 @@ object DateTimeUtils { final val to2001 = -11323 // this is year -17999, calculation: 50 * daysIn400Year + final val YearZero = -17999 final val toYearZero = to2001 + 7304850 @transient lazy val defaultTimeZone = TimeZone.getDefault @@ -573,4 +574,243 @@ object DateTimeUtils { dayInYear - 334 } } + + /** + * The number of days for each month (not leap year) + */ + private val monthDays = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31) + + /** + * Returns the date value for the first day of the given month. + * The month is expressed in months since year zero (17999 BC), starting from 0. + */ + private def firstDayOfMonth(absoluteMonth: Int): Int = { + val absoluteYear = absoluteMonth / 12 + var monthInYear = absoluteMonth - absoluteYear * 12 + var date = getDateFromYear(absoluteYear) + if (monthInYear >= 2 && isLeapYear(absoluteYear + YearZero)) { + date += 1 + } + while (monthInYear > 0) { + date += monthDays(monthInYear - 1) + monthInYear -= 1 + } + date + } + + /** + * Returns the date value for January 1 of the given year. + * The year is expressed in years since year zero (17999 BC), starting from 0. + */ + private def getDateFromYear(absoluteYear: Int): Int = { + val absoluteDays = (absoluteYear * 365 + absoluteYear / 400 - absoluteYear / 100 + + absoluteYear / 4) + absoluteDays - toYearZero + } + + /** + * Add date and year-month interval. + * Returns a date value, expressed in days since 1.1.1970. + */ + def dateAddMonths(days: Int, months: Int): Int = { + val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months + val currentMonthInYear = absoluteMonth % 12 + val currentYear = absoluteMonth / 12 + val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0 + val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay + + val dayOfMonth = getDayOfMonth(days) + val currentDayInMonth = if (getDayOfMonth(days + 1) == 1 || dayOfMonth >= lastDayOfMonth) { + // last day of the month + lastDayOfMonth + } else { + dayOfMonth + } + firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1 + } + + /** + * Add timestamp and full interval. + * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00. + */ + def timestampAddInterval(start: Long, months: Int, microseconds: Long): Long = { + val days = millisToDays(start / 1000L) + val newDays = dateAddMonths(days, months) + daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds + } + + /** + * Returns the last dayInMonth in the month it belongs to. The date is expressed + * in days since 1.1.1970. the return value starts from 1. 
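// Worked example, illustrative only, derived from the last-day branch of dateAddMonths above:
//   add_months(2015-01-31, 1) -> 2015-02-28   (day 31 is clamped to February's last day)
//   add_months(2015-02-28, 1) -> 2015-03-31   (an input on the last day yields the last day)
//   add_months(2015-03-15, 1) -> 2015-04-15   (other days of month are preserved)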
+ */ + private def getLastDayInMonthOfMonth(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + if (dayInYear > 31 && dayInYear <= 60) { + return 29 + } else if (dayInYear > 60) { + dayInYear = dayInYear - 1 + } + } + if (dayInYear <= 31) { + 31 + } else if (dayInYear <= 59) { + 28 + } else if (dayInYear <= 90) { + 31 + } else if (dayInYear <= 120) { + 30 + } else if (dayInYear <= 151) { + 31 + } else if (dayInYear <= 181) { + 30 + } else if (dayInYear <= 212) { + 31 + } else if (dayInYear <= 243) { + 31 + } else if (dayInYear <= 273) { + 30 + } else if (dayInYear <= 304) { + 31 + } else if (dayInYear <= 334) { + 30 + } else { + 31 + } + } + + /** + * Returns number of months between time1 and time2. time1 and time2 are expressed in + * microseconds since 1.1.1970. + * + * If time1 and time2 having the same day of month, or both are the last day of month, + * it returns an integer (time under a day will be ignored). + * + * Otherwise, the difference is calculated based on 31 days per month, and rounding to + * 8 digits. + */ + def monthsBetween(time1: Long, time2: Long): Double = { + val millis1 = time1 / 1000L + val millis2 = time2 / 1000L + val date1 = millisToDays(millis1) + val date2 = millisToDays(millis2) + // TODO(davies): get year, month, dayOfMonth from single function + val dayInMonth1 = getDayOfMonth(date1) + val dayInMonth2 = getDayOfMonth(date2) + val months1 = getYear(date1) * 12 + getMonth(date1) + val months2 = getYear(date2) * 12 + getMonth(date2) + + if (dayInMonth1 == dayInMonth2 || (dayInMonth1 == getLastDayInMonthOfMonth(date1) + && dayInMonth2 == getLastDayInMonthOfMonth(date2))) { + return (months1 - months2).toDouble + } + // milliseconds is enough for 8 digits precision on the right side + val timeInDay1 = millis1 - daysToMillis(date1) + val timeInDay2 = millis2 - daysToMillis(date2) + val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY + val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } + + /* + * Returns day of week from String. Starting from Thursday, marked as 0. + * (Because 1970-01-01 is Thursday). + */ + def getDayOfWeekFromString(string: UTF8String): Int = { + val dowString = string.toString.toUpperCase + dowString match { + case "SU" | "SUN" | "SUNDAY" => 3 + case "MO" | "MON" | "MONDAY" => 4 + case "TU" | "TUE" | "TUESDAY" => 5 + case "WE" | "WED" | "WEDNESDAY" => 6 + case "TH" | "THU" | "THURSDAY" => 0 + case "FR" | "FRI" | "FRIDAY" => 1 + case "SA" | "SAT" | "SATURDAY" => 2 + case _ => -1 + } + } + + /** + * Returns the first date which is later than startDate and is of the given dayOfWeek. + * dayOfWeek is an integer ranges in [0, 6], and 0 is Thu, 1 is Fri, etc,. + */ + def getNextDateForDayOfWeek(startDate: Int, dayOfWeek: Int): Int = { + startDate + 1 + ((dayOfWeek - 1 - startDate) % 7 + 7) % 7 + } + + /** + * Returns last day of the month for the given date. The date is expressed in days + * since 1.1.1970. 
+ */
+ def getLastDayOfMonth(date: Int): Int = {
+ var (year, dayInYear) = getYearAndDayInYear(date)
+ if (isLeapYear(year)) {
+ if (dayInYear > 31 && dayInYear <= 60) {
+ return date + (60 - dayInYear)
+ } else if (dayInYear > 60) {
+ dayInYear = dayInYear - 1
+ }
+ }
+ val lastDayOfMonthInYear = if (dayInYear <= 31) {
+ 31
+ } else if (dayInYear <= 59) {
+ 59
+ } else if (dayInYear <= 90) {
+ 90
+ } else if (dayInYear <= 120) {
+ 120
+ } else if (dayInYear <= 151) {
+ 151
+ } else if (dayInYear <= 181) {
+ 181
+ } else if (dayInYear <= 212) {
+ 212
+ } else if (dayInYear <= 243) {
+ 243
+ } else if (dayInYear <= 273) {
+ 273
+ } else if (dayInYear <= 304) {
+ 304
+ } else if (dayInYear <= 334) {
+ 334
+ } else {
+ 365
+ }
+ date + (lastDayOfMonthInYear - dayInYear)
+ }
+
+ private val TRUNC_TO_YEAR = 1
+ private val TRUNC_TO_MONTH = 2
+ private val TRUNC_INVALID = -1
+
+ /**
+ * Returns the truncated date for the given date and truncation level.
+ * The level should be generated using `parseTruncLevel()` and can only be
+ * TRUNC_TO_YEAR or TRUNC_TO_MONTH.
+ */
+ def truncDate(d: Int, level: Int): Int = {
+ if (level == TRUNC_TO_YEAR) {
+ d - DateTimeUtils.getDayInYear(d) + 1
+ } else if (level == TRUNC_TO_MONTH) {
+ d - DateTimeUtils.getDayOfMonth(d) + 1
+ } else {
+ throw new Exception(s"Invalid trunc level: $level")
+ }
+ }
+
+ /**
+ * Returns the truncation level, which can be TRUNC_TO_YEAR, TRUNC_TO_MONTH, or TRUNC_INVALID;
+ * TRUNC_INVALID means the truncation level is unsupported.
+ */
+ def parseTruncLevel(format: UTF8String): Int = {
+ if (format == null) {
+ TRUNC_INVALID
+ } else {
+ format.toString.toUpperCase match {
+ case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
+ case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
+ case _ => TRUNC_INVALID
+ }
+ }
+ }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 40bf4b299c990..e0667c629486d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -95,7 +95,7 @@ private[sql] object TypeCollection {
 * Types that include numeric types and interval type. They are only used in unary_minus,
 * unary_positive, add and subtract operations.
 */
- val NumericAndInterval = TypeCollection(NumericType, IntervalType)
+ val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)
 def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
new file mode 100644
index 0000000000000..14a7285877622
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
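As a reading aid for the month arithmetic added above (dateAddMonths, monthsBetween): a rough, standalone Scala sketch of the intended semantics. It uses java.time purely for illustration, and the helper name addMonthsSketch is made up here; it is not part of the patch.

import java.time.LocalDate

// Keep the day of month when it exists in the target month, clamp it otherwise,
// and snap to the last day when the input is already the last day of its month.
def addMonthsSketch(d: LocalDate, months: Int): LocalDate = {
  val shifted = d.plusMonths(months) // java.time already clamps overflowing days
  if (d.getDayOfMonth == d.lengthOfMonth()) shifted.withDayOfMonth(shifted.lengthOfMonth())
  else shifted
}

assert(addMonthsSketch(LocalDate.of(2015, 1, 30), 1) == LocalDate.of(2015, 2, 28))  // clamped
assert(addMonthsSketch(LocalDate.of(2015, 2, 28), 1) == LocalDate.of(2015, 3, 31))  // last day stays last day

// monthsBetween, per its doc comment: whole months plus (dayDiff + timeOfDayDiff) / 31,
// rounded to 8 digits. For 1997-02-28 10:30:00 vs 1996-10-30 00:00:00 that is
// 4 + (28 - 30 + 10.5 / 24) / 31 = 3.94959677..., the value asserted later in DateExpressionsSuite.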
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters + +abstract class ArrayData extends SpecializedGetters with Serializable { + // todo: remove this after we handle all types.(map type need special getter) + def get(ordinal: Int): Any + + def numElements(): Int + + // todo: need a more efficient way to iterate array type. + def toArray(): Array[Any] = { + val n = numElements() + val values = new Array[Any](n) + var i = 0 + while (i < n) { + if (isNullAt(i)) { + values(i) = null + } else { + values(i) = get(i) + } + i += 1 + } + values + } + + override def toString(): String = toArray.mkString("[", ",", "]") + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[ArrayData]) { + return false + } + + val other = o.asInstanceOf[ArrayData] + if (other eq null) { + return false + } + + val len = numElements() + if (len != other.numElements()) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numElements() + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + get(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala similarity index 64% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index 87c6e9e6e5e2c..3565f52c21f69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -22,16 +22,19 @@ import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: - * The data type representing time intervals. + * The data type representing calendar time intervals. 
The calendar time interval is stored
+ * internally in two components: the number of months and the number of microseconds.
+ *
- * Please use the singleton [[DataTypes.IntervalType]].
+ * Note that calendar intervals are not comparable.
+ *
+ * Please use the singleton [[DataTypes.CalendarIntervalType]].
 */
 @DeveloperApi
-class IntervalType private() extends DataType {
+class CalendarIntervalType private() extends DataType {
- override def defaultSize: Int = 4096
+ override def defaultSize: Int = 16
- private[spark] override def asNullable: IntervalType = this
+ private[spark] override def asNullable: CalendarIntervalType = this
 }
-case object IntervalType extends IntervalType
+case object CalendarIntervalType extends CalendarIntervalType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 591fb26e67c4a..f4428c2e8b202 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -142,12 +142,21 @@ object DataType {
 ("type", JString("struct"))) =>
 StructType(fields.map(parseStructField))
+ // Scala/Java UDT
 case JSortedObject(
 ("class", JString(udtClass)),
 ("pyClass", _),
 ("sqlType", _),
 ("type", JString("udt"))) =>
 Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
+
+ // Python UDT
+ case JSortedObject(
+ ("pyClass", JString(pyClass)),
+ ("serializedClass", JString(serialized)),
+ ("sqlType", v: JValue),
+ ("type", JString("udt"))) =>
+ new PythonUserDefinedType(parseDataType(v), pyClass, serialized)
 }
 private def parseStructField(json: JValue): StructField = json match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index bc689810bc292..c0155eeb450a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -188,6 +188,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 * @return true if successful, false if overflow would occur
 */
 def changePrecision(precision: Int, scale: Int): Boolean = {
+ // fast path for UnsafeProjection
+ if (precision == this.precision && scale == this.scale) {
+ return true
+ }
 // First, update our longVal if we can, or transfer over to using a BigDecimal
 if (decimalVal.eq(null)) {
 if (scale < _scale) {
@@ -224,7 +228,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 decimalVal = newVal
 } else {
 // We're still using Longs, but we should check whether we match the new precision
- val p = POW_10(math.min(_precision, MAX_LONG_DIGITS))
+ val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
 if (longVal <= -p || longVal >= p) {
 // Note that we shouldn't have been able to fix this by switching to BigDecimal
 return false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
new file mode 100644
index 0000000000000..35ace673fb3da
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
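A note on the Decimal change above: changePrecision now short-circuits when the requested precision and scale already match, and the long-value overflow check uses the new precision rather than the old _precision. A minimal sketch of that bound check for the scale-0 case (the fitsPrecision helper is hypothetical, not Spark's Decimal API):

// An integral value fits DecimalType(p, 0) only if |value| < 10^p.
def fitsPrecision(unscaled: Long, precision: Int): Boolean = {
  val bound = math.pow(10, precision).toLong
  unscaled > -bound && unscaled < bound
}

assert(fitsPrecision(123L, 3))   // cast(123L, DecimalType(3, 0)) stays Decimal(123)
assert(!fitsPrecision(123L, 2))  // cast(123L, DecimalType(2, 0)) overflows, hence the re-enabled null assertion in CastSuite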
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval} + +class GenericArrayData(array: Array[Any]) extends ArrayData { + private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T] + + override def toArray(): Array[Any] = array + + override def get(ordinal: Int): Any = array(ordinal) + + override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null + + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + + override def getByte(ordinal: Int): Byte = getAs(ordinal) + + override def getShort(ordinal: Int): Short = getAs(ordinal) + + override def getInt(ordinal: Int): Int = getAs(ordinal) + + override def getLong(ordinal: Int): Long = getAs(ordinal) + + override def getFloat(ordinal: Int): Float = getAs(ordinal) + + override def getDouble(ordinal: Int): Double = getAs(ordinal) + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + + override def numElements(): Int = array.length +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index e47cfb4833bd8..4305903616bd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -45,6 +45,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Paired Python UDT class, if exists. */ def pyUDT: String = null + /** Serialized Python UDT class, if exists. */ + def serializedPyClass: String = null + /** * Convert the user type to a SQL datum * @@ -82,3 +85,29 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass } + +/** + * ::DeveloperApi:: + * The user defined type in Python. + * + * Note: This can only be accessed via Python UDF, or accessed as serialized object. 
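The point of the ArrayData / GenericArrayData pair introduced above is a stable structural equality for array values: elements are compared position by position, Array[Byte] by content, and NaN compares equal to NaN, with a matching hashCode. A quick sanity sketch of that contract (illustrative only, not a test from this patch):

val a = new GenericArrayData(Array[Any](1, null, Double.NaN, Array[Byte](1, 2)))
val b = new GenericArrayData(Array[Any](1, null, Double.NaN, Array[Byte](1, 2)))
assert(a == b && a.hashCode == b.hashCode)  // byte arrays compared by value, NaN equal to NaN
assert(a != new GenericArrayData(Array[Any](1, null, 2.0, Array[Byte](1, 2))))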
+ */
+private[sql] class PythonUserDefinedType(
+ val sqlType: DataType,
+ override val pyUDT: String,
+ override val serializedPyClass: String) extends UserDefinedType[Any] {
+
+ /* The serialization is handled by the UDT class in Python */
+ override def serialize(obj: Any): Any = obj
+ override def deserialize(datum: Any): Any = datum
+
+ /* There is no Java class for Python UDT */
+ override def userClass: java.lang.Class[Any] = null
+
+ override private[sql] def jsonValue: JValue = {
+ ("type" -> "udt") ~
+ ("pyClass" -> pyUDT) ~
+ ("serializedClass" -> serializedPyClass) ~
+ ("sqlType" -> sqlType.jsonValue)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ed645b618dc9b..a86cefe941e8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -145,15 +145,15 @@ class AnalysisSuite extends AnalysisTest {
 'e / 'e as 'div5))
 val pl = plan.asInstanceOf[Project].projectList
- // StringType will be promoted into Double
 assert(pl(0).dataType == DoubleType)
 assert(pl(1).dataType == DoubleType)
 assert(pl(2).dataType == DoubleType)
- assert(pl(3).dataType == DoubleType)
+ // StringType will be promoted into Decimal(38, 18)
+ assert(pl(3).dataType == DecimalType(38, 29))
 assert(pl(4).dataType == DoubleType)
 }
- test("pull out nondeterministic expressions from unary LogicalPlan") {
+ test("pull out nondeterministic expressions from RepartitionByExpression") {
 val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
 val projected = Alias(Rand(33), "_nondeterministic")()
 val expected =
@@ -162,4 +162,42 @@ class AnalysisSuite extends AnalysisTest {
 Project(testRelation.output :+ projected, testRelation)))
 checkAnalysis(plan, expected)
 }
+
+ test("pull out nondeterministic expressions from Sort") {
+ val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation)
+ val analyzed = caseSensitiveAnalyzer.execute(plan)
+ analyzed.transform {
+ case s: Sort if s.expressions.exists(!_.deterministic) =>
+ fail("nondeterministic expressions are not allowed in Sort")
+ }
+ }
+
+ test("remove still-need-evaluate ordering expressions from sort") {
+ val a = testRelation2.output(0)
+ val b = testRelation2.output(1)
+
+ def makeOrder(e: Expression): SortOrder = SortOrder(e, Ascending)
+
+ val noEvalOrdering = makeOrder(a)
+ val noEvalOrderingWithAlias = makeOrder(Alias(Alias(b, "name1")(), "name2")())
+
+ val needEvalExpr = Coalesce(Seq(a, Literal("1")))
+ val needEvalExpr2 = Coalesce(Seq(a, b))
+ val needEvalOrdering = makeOrder(needEvalExpr)
+ val needEvalOrdering2 = makeOrder(needEvalExpr2)
+
+ val plan = Sort(
+ Seq(noEvalOrdering, noEvalOrderingWithAlias, needEvalOrdering, needEvalOrdering2),
+ false, testRelation2)
+
+ val evaluatedOrdering = makeOrder(AttributeReference("_sortCondition", StringType)())
+ val materializedExprs = Seq(needEvalExpr, needEvalExpr2).map(e => Alias(e, "_sortCondition")())
+
+ val expected =
+ Project(testRelation2.output,
+ Sort(Seq(makeOrder(a), makeOrder(b), evaluatedOrdering, evaluatedOrdering), false,
+ Project(testRelation2.output ++ materializedExprs, testRelation2)))
+
+ checkAnalysis(plan, expected)
+ }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ad15136ee9a2f..a52e4cb4dfd9f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "type (numeric or interval)") + assertError(UnaryMinus('stringField), "type (numeric or calendarinterval)") assertError(Abs('stringField), "expected to be of type numeric") assertError(BitwiseNot('stringField), "expected to be of type integral") } @@ -78,8 +78,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type") - assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type") + assertError(Add('booleanField, 'booleanField), "accepts (numeric or calendarinterval) type") + assertError(Subtract('booleanField, 'booleanField), + "accepts (numeric or calendarinterval) type") assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") assertError(Divide('booleanField, 'booleanField), "accepts numeric type") assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") @@ -166,10 +167,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Odd position only allow foldable and not-null StringType expressions") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Odd position only allow foldable and not-null StringType expressions") + "Only foldable StringType expressions are allowed to appear at odd position") + assertError( + CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), + "Field name should not be null") } test("check types for ROUND") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 4454d51b75877..70608771dd110 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.analysis +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval class HiveTypeCoercionSuite extends PlanTest { @@ -116,7 +119,7 @@ class HiveTypeCoercionSuite extends PlanTest { shouldNotCast(IntegerType, MapType) shouldNotCast(IntegerType, StructType) - shouldNotCast(IntervalType, StringType) + shouldNotCast(CalendarIntervalType, StringType) // Don't implicitly cast complex types to string. 
shouldNotCast(ArrayType(StringType), StringType) @@ -400,6 +403,33 @@ class HiveTypeCoercionSuite extends PlanTest { } } + test("rule for date/timestamp operations") { + val dateTimeOperations = HiveTypeCoercion.DateTimeOperations + val date = Literal(new java.sql.Date(0L)) + val timestamp = Literal(new Timestamp(0L)) + val interval = Literal(new CalendarInterval(0, 0)) + val str = Literal("2015-01-01") + + ruleTest(dateTimeOperations, Add(date, interval), Cast(TimeAdd(date, interval), DateType)) + ruleTest(dateTimeOperations, Add(interval, date), Cast(TimeAdd(date, interval), DateType)) + ruleTest(dateTimeOperations, Add(timestamp, interval), + Cast(TimeAdd(timestamp, interval), TimestampType)) + ruleTest(dateTimeOperations, Add(interval, timestamp), + Cast(TimeAdd(timestamp, interval), TimestampType)) + ruleTest(dateTimeOperations, Add(str, interval), Cast(TimeAdd(str, interval), StringType)) + ruleTest(dateTimeOperations, Add(interval, str), Cast(TimeAdd(str, interval), StringType)) + + ruleTest(dateTimeOperations, Subtract(date, interval), Cast(TimeSub(date, interval), DateType)) + ruleTest(dateTimeOperations, Subtract(timestamp, interval), + Cast(TimeSub(timestamp, interval), TimestampType)) + ruleTest(dateTimeOperations, Subtract(str, interval), Cast(TimeSub(str, interval), StringType)) + + // interval operations should not be effected + ruleTest(dateTimeOperations, Add(interval, interval), Add(interval, interval)) + ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval)) + } + + /** * There are rules that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e7e5231d32c9e..d03b0fbbfb2b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -116,9 +116,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("Abs") { testNumericDataTypes { convert => + val input = Literal(convert(1)) + val dataType = input.dataType checkEvaluation(Abs(Literal(convert(0))), convert(0)) checkEvaluation(Abs(Literal(convert(1))), convert(1)) checkEvaluation(Abs(Literal(convert(-1))), convert(1)) + checkEvaluation(Abs(Literal.create(null, dataType)), null) } } @@ -170,6 +173,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(-7, 3), 2) checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005) checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) - checkEvaluation(Pmod(2L, Long.MaxValue), 2) + checkEvaluation(Pmod(2L, Long.MaxValue), 2L) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala index 648fbf5a4c30b..fa30fbe528479 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -30,8 +30,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - 
check(1.toByte, ~1.toByte) - check(1000.toShort, ~1000.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, (~1.toByte).toByte) + check(1000.toShort, (~1000.toShort).toShort) check(1000000, ~1000000) check(123456789123L, ~123456789123L) @@ -45,8 +46,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte & 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort & 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort) check(1000000, 4, 1000000 & 4) check(123456789123L, 5L, 123456789123L & 5L) @@ -63,8 +65,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte | 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort | 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort) check(1000000, 4, 1000000 | 4) check(123456789123L, 5L, 123456789123L | 5L) @@ -81,8 +84,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte ^ 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort ^ 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort) check(1000000, 4, 1000000 ^ 4) check(123456789123L, 5L, 123456789123L ^ 5L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 408353cf70a49..1ad70733eae03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Timestamp, Date} import java.util.{TimeZone, Calendar} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -43,6 +44,42 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(v, Literal(expected).dataType), expected) } + private def checkNullCast(from: DataType, to: DataType): Unit = { + checkEvaluation(Cast(Literal.create(null, from), to), null) + } + + test("null cast") { + import DataTypeTestUtils._ + + // follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic + // to ensure we test every possible cast situation here + atomicTypes.zip(atomicTypes).foreach { case (from, to) => + checkNullCast(from, to) + } + + atomicTypes.foreach(dt => checkNullCast(NullType, dt)) + atomicTypes.foreach(dt => checkNullCast(dt, StringType)) + checkNullCast(StringType, BinaryType) + checkNullCast(StringType, BooleanType) + checkNullCast(DateType, BooleanType) + checkNullCast(TimestampType, BooleanType) + numericTypes.foreach(dt => checkNullCast(dt, BooleanType)) + + checkNullCast(StringType, TimestampType) + checkNullCast(BooleanType, TimestampType) + 
checkNullCast(DateType, TimestampType) + numericTypes.foreach(dt => checkNullCast(dt, TimestampType)) + + atomicTypes.foreach(dt => checkNullCast(dt, DateType)) + + checkNullCast(StringType, CalendarIntervalType) + numericTypes.foreach(dt => checkNullCast(StringType, dt)) + numericTypes.foreach(dt => checkNullCast(BooleanType, dt)) + numericTypes.foreach(dt => checkNullCast(DateType, dt)) + numericTypes.foreach(dt => checkNullCast(TimestampType, dt)) + for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to) + } + test("cast string to date") { var c = Calendar.getInstance() c.set(2015, 0, 1, 0, 0, 0) @@ -69,8 +106,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to timestamp") { - checkEvaluation(Cast(Literal("123"), TimestampType), - null) + checkEvaluation(Cast(Literal("123"), TimestampType), null) var c = Calendar.getInstance() c.set(2015, 0, 1, 0, 0, 0) @@ -206,10 +242,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) - checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) + checkEvaluation(cast(123L, DecimalType(3, 1)), null) - // TODO: Fix the following bug and re-enable it. - // checkEvaluation(cast(123L, DecimalType(2, 0)), null) + checkEvaluation(cast(123L, DecimalType(2, 0)), null) } test("cast from boolean") { @@ -473,6 +508,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val array_notNull = Literal.create(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) + { val ret = cast(array, ArrayType(IntegerType, containsNull = true)) assert(ret.resolved === true) @@ -526,6 +563,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) + checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) + { val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) assert(ret.resolved === true) @@ -580,6 +619,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from struct") { + checkNullCast( + StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))), + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType)))) + val struct = Literal.create( InternalRow( UTF8String.fromString("123"), @@ -683,13 +730,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( - InternalRow( - Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")), - Map( - UTF8String.fromString("a") -> UTF8String.fromString("123"), - UTF8String.fromString("b") -> UTF8String.fromString("abc"), - UTF8String.fromString("c") -> UTF8String.fromString("")), - InternalRow(0)), + Row( + Seq("123", "abc", ""), + Map("a" ->"123", "b" -> "abc", "c" -> ""), + Row(0)), StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false), nullable = true), @@ -709,23 +753,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("l", LongType, nullable = true))))))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow( + checkEvaluation(ret, Row( Seq(123, null, null), - Map( - UTF8String.fromString("a") -> true, - UTF8String.fromString("b") -> 
true, - UTF8String.fromString("c") -> false), - InternalRow(0L))) + Map("a" -> true, "b" -> true, "c" -> false), + Row(0L))) } test("case between string and interval") { - import org.apache.spark.unsafe.types.Interval + import org.apache.spark.unsafe.types.CalendarInterval - checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType), - new Interval(-3, 7 * Interval.MICROS_PER_HOUR)) + checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), + new CalendarInterval(-3, 7 * CalendarInterval.MICROS_PER_HOUR)) checkEvaluation(Cast(Literal.create( - new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType), + new CalendarInterval(15, -3 * CalendarInterval.MICROS_PER_DAY), CalendarIntervalType), + StringType), "interval 1 years 3 months -3 days") } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index fc842772f3480..3fa246b69d1f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -110,7 +110,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { expr.dataType match { case ArrayType(StructType(fields), containsNull) => val field = fields.find(_.name == fieldName).get - GetArrayStructFields(expr, field, fields.indexOf(field), containsNull) + GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull) } } @@ -132,6 +132,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) } test("CreateStruct") { @@ -139,26 +140,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val c1 = 'a.int.at(0) val c3 = 'c.int.at(2) checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) + checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null)) } test("CreateNamedStruct") { - val row = InternalRow(1, 2, 3) + val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) val c3 = 'c.int.at(2) - checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row) - } - - test("CreateNamedStruct with literal field") { - val row = InternalRow(1, 2, 3) - val c1 = 'a.int.at(0) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row) checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), - InternalRow(1, UTF8String.fromString("y")), row) - } - - test("CreateNamedStruct from all literal fields") { - checkEvaluation( - CreateNamedStruct(Seq("a", "x", "b", 2.0)), - InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty) + create_row(1, UTF8String.fromString("y")), row) + checkEvaluation(CreateNamedStruct(Seq("a", "x", "b", 2.0)), + create_row(UTF8String.fromString("x"), 2.0)) + checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))), + create_row(null)) } test("test dsl for complex type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index b31d6661c8c1c..d26bcdb2902ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -149,6 +149,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Least(Seq(c1, c2, Literal(-1))), -1, row) checkEvaluation(Least(Seq(c4, c5, c3, c3, Literal("a"))), "a", row) + val nullLiteral = Literal.create(null, IntegerType) + checkEvaluation(Least(Seq(nullLiteral, nullLiteral)), null) checkEvaluation(Least(Seq(Literal(null), Literal(null))), null, InternalRow.empty) checkEvaluation(Least(Seq(Literal(-1.0), Literal(2.5))), -1.0, InternalRow.empty) checkEvaluation(Least(Seq(Literal(-1), Literal(2))), -1, InternalRow.empty) @@ -188,6 +190,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Greatest(Seq(c1, c2, Literal(2))), 2, row) checkEvaluation(Greatest(Seq(c4, c5, c3, Literal("ccc"))), "ccc", row) + val nullLiteral = Literal.create(null, IntegerType) + checkEvaluation(Greatest(Seq(nullLiteral, nullLiteral)), null) checkEvaluation(Greatest(Seq(Literal(null), Literal(null))), null, InternalRow.empty) checkEvaluation(Greatest(Seq(Literal(-1.0), Literal(2.5))), 2.5, InternalRow.empty) checkEvaluation(Greatest(Seq(Literal(-1), Literal(2))), 2, InternalRow.empty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index bdba6ce891386..6c15c05da3094 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Calendar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{StringType, TimestampType, DateType} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.sql.types._ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -31,58 +33,23 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) + test("datetime function current_date") { + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int] + val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) + } + + test("datetime function current_timestamp") { + val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long]) + val t1 = System.currentTimeMillis() + assert(math.abs(t1 - ct.getTime) < 5000) + } + test("DayOfYear") { val sdfDay = new SimpleDateFormat("D") - (2002 to 2004).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new 
Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - (1998 to 2002).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - - (1969 to 1970).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - - (2402 to 2404).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - - (2398 to 2402).foreach { y => - (0 to 11).foreach { m => + (0 to 3).foreach { m => (0 to 5).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) @@ -92,6 +59,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkEvaluation(DayOfYear(Literal.create(null, DateType)), null) } test("Year") { @@ -101,7 +69,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013) val c = Calendar.getInstance() - (2000 to 2010).foreach { y => + (2000 to 2002).foreach { y => (0 to 11 by 11).foreach { m => c.set(y, m, 28) (0 to 5 * 24).foreach { i => @@ -139,20 +107,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Month(Cast(Literal(ts), DateType)), 11) (2003 to 2004).foreach { y => - (0 to 11).foreach { m => - (0 to 5 * 24).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.HOUR_OF_DAY, i) - checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), - c.get(Calendar.MONTH) + 1) - } - } - } - - (1999 to 2000).foreach { y => - (0 to 11).foreach { m => - (0 to 5 * 24).foreach { i => + (0 to 3).foreach { m => + (0 to 2 * 24).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.HOUR_OF_DAY, i) @@ -246,4 +202,235 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("date_add") { + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29"))) + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(-365)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28"))) + checkEvaluation(DateAdd(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(DateAdd(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)), + null) + checkEvaluation(DateAdd(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("date_sub") { + checkEvaluation( + DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2014-12-31"))) + checkEvaluation( + DateSub(Literal(Date.valueOf("2015-01-01")), Literal(-1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-02"))) + checkEvaluation(DateSub(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(DateSub(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)), + null) + checkEvaluation(DateSub(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) 
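The time_add/time_sub cases that follow exercise the timestampAddInterval helper added earlier in this patch: it shifts the date part by the month component, carries the time of day over, and then adds the microsecond component. A standalone illustration of the first case (java.time is used only for the sketch and time zones are ignored):

import java.time.LocalDateTime

val start = LocalDateTime.of(2016, 1, 29, 10, 0, 0)
val shifted = start.plusMonths(1).plusNanos(123000L * 1000)  // 1 month, then 123000 microseconds
assert(shifted == LocalDateTime.of(2016, 2, 29, 10, 0, 0, 123000000))  // "2016-02-29 10:00:00.123"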
+ } + + test("time_add") { + checkEvaluation( + TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal(new CalendarInterval(1, 123000L))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00.123"))) + + checkEvaluation( + TimeAdd(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), + null) + checkEvaluation( + TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal.create(null, CalendarIntervalType)), + null) + checkEvaluation( + TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), + null) + } + + test("time_sub") { + checkEvaluation( + TimeSub(Literal(Timestamp.valueOf("2016-03-31 10:00:00")), + Literal(new CalendarInterval(1, 0))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00"))) + checkEvaluation( + TimeSub( + Literal(Timestamp.valueOf("2016-03-30 00:00:01")), + Literal(new CalendarInterval(1, 2000000.toLong))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-28 23:59:59"))) + + checkEvaluation( + TimeSub(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), + null) + checkEvaluation( + TimeSub(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal.create(null, CalendarIntervalType)), + null) + checkEvaluation( + TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), + null) + } + + test("add_months") { + checkEvaluation(AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28"))) + checkEvaluation(AddMonths(Literal(Date.valueOf("2016-03-30")), Literal(-1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29"))) + checkEvaluation( + AddMonths(Literal(Date.valueOf("2015-01-30")), Literal.create(null, IntegerType)), + null) + checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("months_between") { + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("1997-02-28 10:30:00")), + Literal(Timestamp.valueOf("1996-10-30 00:00:00"))), + 3.94959677) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-01-30 11:52:00")), + Literal(Timestamp.valueOf("2015-01-30 11:50:00"))), + 0.0) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-01-31 00:00:00")), + Literal(Timestamp.valueOf("2015-03-31 22:00:00"))), + -2.0) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-03-31 22:00:00")), + Literal(Timestamp.valueOf("2015-02-28 00:00:00"))), + 1.0) + val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) + val tnull = Literal.create(null, TimestampType) + checkEvaluation(MonthsBetween(t, tnull), null) + checkEvaluation(MonthsBetween(tnull, t), null) + checkEvaluation(MonthsBetween(tnull, tnull), null) + } + + test("last_day") { + checkEvaluation(LastDay(Literal(Date.valueOf("2015-02-28"))), Date.valueOf("2015-02-28")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-03-27"))), Date.valueOf("2015-03-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-04-26"))), Date.valueOf("2015-04-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-05-25"))), Date.valueOf("2015-05-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-06-24"))), Date.valueOf("2015-06-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-07-23"))), Date.valueOf("2015-07-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-08-01"))), 
Date.valueOf("2015-08-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-09-02"))), Date.valueOf("2015-09-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-10-03"))), Date.valueOf("2015-10-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-11-04"))), Date.valueOf("2015-11-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-12-05"))), Date.valueOf("2015-12-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) + checkEvaluation(LastDay(Literal.create(null, DateType)), null) + } + + test("next_day") { + def testNextDay(input: String, dayOfWeek: String, output: String): Unit = { + checkEvaluation( + NextDay(Literal(Date.valueOf(input)), NonFoldableLiteral(dayOfWeek)), + DateTimeUtils.fromJavaDate(Date.valueOf(output))) + checkEvaluation( + NextDay(Literal(Date.valueOf(input)), Literal(dayOfWeek)), + DateTimeUtils.fromJavaDate(Date.valueOf(output))) + } + testNextDay("2015-07-23", "Mon", "2015-07-27") + testNextDay("2015-07-23", "mo", "2015-07-27") + testNextDay("2015-07-23", "Tue", "2015-07-28") + testNextDay("2015-07-23", "tu", "2015-07-28") + testNextDay("2015-07-23", "we", "2015-07-29") + testNextDay("2015-07-23", "wed", "2015-07-29") + testNextDay("2015-07-23", "Thu", "2015-07-30") + testNextDay("2015-07-23", "TH", "2015-07-30") + testNextDay("2015-07-23", "Fri", "2015-07-24") + testNextDay("2015-07-23", "fr", "2015-07-24") + + checkEvaluation(NextDay(Literal(Date.valueOf("2015-07-23")), Literal("xx")), null) + checkEvaluation(NextDay(Literal.create(null, DateType), Literal("xx")), null) + checkEvaluation( + NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) + } + + test("function to_date") { + checkEvaluation( + ToDate(Literal(Date.valueOf("2015-07-22"))), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22"))) + checkEvaluation(ToDate(Literal.create(null, DateType)), null) + } + + test("function trunc") { + def testTrunc(input: Date, fmt: String, expected: Date): Unit = { + checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)), + expected) + checkEvaluation( + TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)), + expected) + } + val date = Date.valueOf("2015-07-22") + Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach{ fmt => + testTrunc(date, fmt, Date.valueOf("2015-01-01")) + } + Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt => + testTrunc(date, fmt, Date.valueOf("2015-07-01")) + } + testTrunc(date, "DD", null) + testTrunc(date, null, null) + testTrunc(null, "MON", null) + testTrunc(null, null, null) + } + + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(0))) + checkEvaluation(FromUnixTime( + Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(1000000))) + checkEvaluation( + FromUnixTime(Literal(-1000L), Literal(fmt2)), sdf2.format(new Timestamp(-1000000))) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType)), null) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(FromUnixTime(Literal(1000L), Literal.create(null, StringType)), null) + 
checkEvaluation( + FromUnixTime(Literal(0L), Literal("not a valid format")), null) + } + + test("unix_timestamp") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3) + val date1 = Date.valueOf("2015-07-24") + checkEvaluation( + UnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + UnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L) + checkEvaluation( + UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L) + checkEvaluation(UnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L) + val t1 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + UnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null) + checkEvaluation( + UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(UnixTimestamp( + Literal(date1), Literal.create(null, StringType)), date1.getTime / 1000L) + checkEvaluation( + UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala deleted file mode 100644 index 1618c24871c60..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.DateTimeUtils - -class DatetimeFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("datetime function current_date") { - val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int] - val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) - } - - test("datetime function current_timestamp") { - val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long]) - val t1 = System.currentTimeMillis() - assert(math.abs(t1 - ct.getTime) < 5000) - } - -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index ab0cdc857c80e..3c05e5c3b833c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -65,7 +65,7 @@ trait ExpressionEvalHelper { protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => } expression.eval(inputRow) @@ -82,6 +82,7 @@ trait ExpressionEvalHelper { s""" |Code generation of $expression failed: |$e + |${e.getStackTraceString} """.stripMargin) } } @@ -114,7 +115,7 @@ trait ExpressionEvalHelper { val actual = plan(inputRow).get(0, expression.dataType) if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") } } @@ -146,7 +147,8 @@ trait ExpressionEvalHelper { if (actual != expectedRow) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") + fail("Incorrect Evaluation in codegen mode: " + + s"$expression, actual: $actual, expected: $expectedRow$input") } if (actual.copy() != expectedRow) { fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") @@ -163,12 +165,21 @@ trait ExpressionEvalHelper { expression) val unsafeRow = plan(inputRow) - // UnsafeRow cannot be compared with GenericInternalRow directly - val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow) - val expectedRow = InternalRow(expected) - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected) + val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: 
$expectedRow$input") + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 21459a7c69838..9fcb548af6bbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -110,35 +110,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } - test("conv") { - checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") - checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") - checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) - checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) - checkEvaluation( - Conv(Literal("1234"), Literal(10), Literal(37)), null) - checkEvaluation( - Conv(Literal(""), Literal(10), Literal(16)), null) - checkEvaluation( - Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") - // If there is an invalid digit in the number, the longest valid prefix should be converted. - checkEvaluation( - Conv(Literal("11abc"), Literal(10), Literal(16)), "B") - } - private def checkNaN( - expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { checkNaNWithoutCodegen(expression, inputRow) checkNaNWithGeneratedProjection(expression, inputRow) checkNaNWithOptimization(expression, inputRow) } private def checkNaNWithoutCodegen( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } @@ -149,7 +131,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - private def checkNaNWithGeneratedProjection( expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { @@ -172,6 +153,25 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } + test("conv") { + checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") + checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") + checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null) + checkEvaluation( + Conv(Literal("1234"), Literal(10), Literal(37)), null) + checkEvaluation( + Conv(Literal(""), Literal(10), Literal(16)), null) + checkEvaluation( + Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") + // If there is an invalid digit in the number, the longest valid prefix should be 
converted. + checkEvaluation( + Conv(Literal("11abc"), Literal(10), Literal(16)), "B") + } + test("e") { testLeaf(EulerNumber, math.E) } @@ -417,7 +417,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("round") { - val domain = -6 to 6 + val scales = -6 to 6 val doublePi: Double = math.Pi val shortPi: Short = 31415 val intPi: Int = 314159265 @@ -437,17 +437,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ Seq.fill(7)(31415926535897932L) - val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), - BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), - BigDecimal(3.141593), BigDecimal(3.1415927)) - - domain.zipWithIndex.foreach { case (scale, i) => + scales.zipWithIndex.foreach { case (scale, i) => checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) } + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) // round_scale > current_scale would result in precision increase // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null (0 to 7).foreach { i => @@ -456,5 +455,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (8 to 10).foreach { scale => checkEvaluation(Round(bdPi, scale), null, EmptyRow) } + + DataTypeTestUtils.numericTypes.foreach { dataType => + checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null) + checkEvaluation(Round(Literal.create(null, dataType), + Literal.create(null, IntegerType)), null) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala new file mode 100644 index 0000000000000..31ecf4a9e810a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types._ + + +/** + * A literal value that is not foldable. Used in expression codegen testing to test code path + * that behave differently based on foldable values. 
+ */ +case class NonFoldableLiteral(value: Any, dataType: DataType) + extends LeafExpression with CodegenFallback { + + override def foldable: Boolean = false + override def nullable: Boolean = true + + override def toString: String = if (value != null) value.toString else "null" + + override def eval(input: InternalRow): Any = value + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + Literal.create(value, dataType).genCode(ctx, ev) + } +} + + +object NonFoldableLiteral { + def apply(value: Any): NonFoldableLiteral = { + val lit = Literal(value) + NonFoldableLiteral(lit.value, lit.dataType) + } + def create(value: Any, dataType: DataType): NonFoldableLiteral = { + val lit = Literal.create(value, dataType) + NonFoldableLiteral(lit.value, lit.dataType) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala similarity index 76% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala index 99e11fd64b2b9..bf1c930c0bd0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala @@ -15,18 +15,20 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.expression +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions. ExpressionEvalHelper -import org.apache.spark.sql.execution.expressions.{SparkPartitionID, MonotonicallyIncreasingID} class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("MonotonicallyIncreasingID") { - checkEvaluation(MonotonicallyIncreasingID(), 0) + checkEvaluation(MonotonicallyIncreasingID(), 0L) } test("SparkPartitionID") { - checkEvaluation(SparkPartitionID, 0) + checkEvaluation(SparkPartitionID(), 0) + } + + test("InputFileName") { + checkEvaluation(InputFileName(), "") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 698c81ba24482..4a644d136f09c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -20,9 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.DoubleType - class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -30,4 +27,9 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001) checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001) } + + test("SPARK-9127 codegen with long seed") { + checkDoubleEvaluation(Rand(5419823303878592871L), 0.4061913198963727 +- 0.001) + checkDoubleEvaluation(Randn(5419823303878592871L), -0.24417152005343168 +- 0.001) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3d294fda5d103..07b952531ec2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -348,6 +348,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) // scalastyle:on + checkEvaluation(StringTrim(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null) } test("FORMAT") { @@ -391,6 +394,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val s3 = 'c.string.at(2) val s4 = 'd.int.at(3) val row1 = create_row("aaads", "aa", "zz", 1) + val row2 = create_row(null, "aa", "zz", 0) + val row3 = create_row("aaads", null, "zz", 0) + val row4 = create_row(null, null, null, 0) checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) @@ -402,6 +408,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringLocate(s2, s1, s4), 2, row1) checkEvaluation(new StringLocate(s3, s1), 0, row1) checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1) + checkEvaluation(new StringLocate(s2, s1), null, row2) + checkEvaluation(new StringLocate(s2, s1), null, row3) + checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4) } test("LPAD/RPAD") { @@ -448,6 +457,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("abccc") checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) checkEvaluation(StringReverse(s), "cccba", row1) + checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { @@ -466,6 +476,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("100-200", "(\\d+)", "num") val row2 = create_row("100-200", "(\\d+)", "###") val row3 = create_row("100-200", "(-)", "###") + val row4 = create_row(null, "(\\d+)", "###") + val row5 = create_row("100-200", null, "###") + val row6 = create_row("100-200", "(-)", null) val s = 's.string.at(0) val p = 'p.string.at(1) @@ -475,6 +488,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, "num-num", row1) checkEvaluation(expr, "###-###", row2) checkEvaluation(expr, "100###200", row3) + checkEvaluation(expr, null, row4) + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) } test("RegexExtract") { @@ -482,6 +498,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) val row3 = create_row("100-200", "(\\d+).*", 1) val row4 = create_row("100-200", "([a-z])", 1) + val row5 = create_row(null, "([a-z])", 1) + val row6 = create_row("100-200", null, 1) + val row7 = create_row("100-200", "([a-z])", null) val s = 's.string.at(0) val p = 'p.string.at(1) @@ -492,6 +511,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, "200", row2) checkEvaluation(expr, "100", 
row3) checkEvaluation(expr, "", row4) // will not match anything, empty string get + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) + checkEvaluation(expr, null, row7) val expr1 = new RegExpExtract(s, p) checkEvaluation(expr1, "100", row1) @@ -501,11 +523,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val s1 = 'a.string.at(0) val s2 = 'b.string.at(1) val row1 = create_row("aa2bb3cc", "[1-9]+") + val row2 = create_row(null, "[1-9]+") + val row3 = create_row("aa2bb3cc", null) checkEvaluation( StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) + checkEvaluation(StringSplit(s1, s2), null, row2) + checkEvaluation(StringSplit(s1, s2), null, row3) } test("length for string / binary") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 48b7dc57451a3..c6b4c729de2f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -39,6 +39,7 @@ class UnsafeFixedWidthAggregationMapSuite private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) private def emptyAggregationBuffer: InternalRow = InternalRow(0) + private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes private var memoryManager: TaskMemoryManager = null @@ -54,13 +55,13 @@ class UnsafeFixedWidthAggregationMapSuite } test("supported schemas") { + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) - assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) - assert( !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - assert( - !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) } test("empty map") { @@ -69,7 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, memoryManager, - 1024, // initial capacity + 1024, // initial capacity, + PAGE_SIZE_BYTES, false // disable perf metrics ) assert(!map.iterator().hasNext) @@ -83,6 +85,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, memoryManager, 1024, // initial capacity + PAGE_SIZE_BYTES, false // disable perf metrics ) val groupKey = InternalRow(UTF8String.fromString("cats")) @@ -109,6 +112,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, memoryManager, 128, // initial capacity + PAGE_SIZE_BYTES, false // disable perf metrics ) val rand = new Random(42) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 2834b54e8fb2e..a0e1701339ea7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -46,7 +46,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) - // We can copy UnsafeRows as long as they don't reference ObjectPools val unsafeRowCopy = unsafeRow.copy() assert(unsafeRowCopy.getLong(0) === 0) assert(unsafeRowCopy.getLong(1) === 1) @@ -122,8 +121,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { FloatType, DoubleType, StringType, - BinaryType - // DecimalType.Default, + BinaryType, + DecimalType.USER_DEFAULT // ArrayType(IntegerType) ) val converter = UnsafeProjection.create(fieldTypes) @@ -146,11 +145,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getShort(3) === 0) assert(createdFromNull.getInt(4) === 0) assert(createdFromNull.getLong(5) === 0) - assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) + assert(createdFromNull.getFloat(6) === 0.0f) + assert(createdFromNull.getDouble(7) === 0.0d) assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) - // assert(createdFromNull.get(10) === null) + assert(createdFromNull.getDecimal(10, 10, 0) === null) // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those @@ -168,7 +167,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setDouble(7, 700) r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) - // r.update(10, Decimal(10)) + r.setDecimal(10, Decimal(10), 10) // r.update(11, Array(11)) r } @@ -184,7 +183,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) - // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + rowWithNoNullColumns.getDecimal(10, 10, 0)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- fieldTypes.indices) { @@ -203,7 +203,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { setToNullAfterCreation.setDouble(7, 700) // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) // setToNullAfterCreation.update(9, "world".getBytes) - // setToNullAfterCreation.update(10, Decimal(10)) + setToNullAfterCreation.setDecimal(10, Decimal(10), 10) // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) @@ -216,7 +216,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) // assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + rowWithNoNullColumns.getDecimal(10, 10, 0)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 478702fea6146..46daa3eb8bf80 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -73,4 +73,34 @@ class CodeFormatterSuite extends SparkFunSuite { |} """.stripMargin } + + testCase("if else on the same line") { + """ + |class A { + | if (c) {duh;} else {boo;} + |} + """.stripMargin + }{ + """ + |class A { + | if (c) {duh;} else {boo;} + |} + """.stripMargin + } + + testCase("function calls") { + """ + |foo( + |a, + |b, + |c) + """.stripMargin + }{ + """ + |foo( + | a, + | b, + | c) + """.stripMargin + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala new file mode 100644 index 0000000000000..2d3f98dbbd3d1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{BooleanType, DataType} + +/** + * A test suite that makes sure code generation handles expressions' internal states correctly. + */ +class CodegenExpressionCachingSuite extends SparkFunSuite { + + test("GenerateUnsafeProjection should initialize expressions") { + // Use an And to wrap two of them together in case we only initialize the top level expressions. 
+ val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = UnsafeProjection.create(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateProjection.generate(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateMutableProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateMutableProjection.generate(Seq(expr))() + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GeneratePredicate should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GeneratePredicate.generate(expr) + assert(instance.apply(null) === false) + } + + test("GenerateUnsafeProjection should not share expression instances") { + val expr1 = MutableExpression() + val instance1 = UnsafeProjection.create(Seq(expr1)) + assert(instance1.apply(null).getBoolean(0) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = UnsafeProjection.create(Seq(expr2)) + assert(instance1.apply(null).getBoolean(0) === false) + assert(instance2.apply(null).getBoolean(0) === true) + } + + test("GenerateProjection should not share expression instances") { + val expr1 = MutableExpression() + val instance1 = GenerateProjection.generate(Seq(expr1)) + assert(instance1.apply(null).getBoolean(0) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = GenerateProjection.generate(Seq(expr2)) + assert(instance1.apply(null).getBoolean(0) === false) + assert(instance2.apply(null).getBoolean(0) === true) + } + + test("GenerateMutableProjection should not share expression instances") { + val expr1 = MutableExpression() + val instance1 = GenerateMutableProjection.generate(Seq(expr1))() + assert(instance1.apply(null).getBoolean(0) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = GenerateMutableProjection.generate(Seq(expr2))() + assert(instance1.apply(null).getBoolean(0) === false) + assert(instance2.apply(null).getBoolean(0) === true) + } + + test("GeneratePredicate should not share expression instances") { + val expr1 = MutableExpression() + val instance1 = GeneratePredicate.generate(expr1) + assert(instance1.apply(null) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = GeneratePredicate.generate(expr2) + assert(instance1.apply(null) === false) + assert(instance2.apply(null) === true) + } + +} + +/** + * An expression that's non-deterministic and doesn't support codegen. + */ +case class NondeterministicExpression() + extends LeafExpression with Nondeterministic with CodegenFallback { + override protected def initInternal(): Unit = { } + override protected def evalInternal(input: InternalRow): Any = false + override def nullable: Boolean = false + override def dataType: DataType = BooleanType +} + + +/** + * An expression with mutable state so we can change it freely in our test suite. 
+ */ +case class MutableExpression() extends LeafExpression with CodegenFallback { + var mutableState: Boolean = false + override def eval(input: InternalRow): Any = mutableState + + override def nullable: Boolean = false + override def dataType: DataType = BooleanType +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index fab9eb9cd4c9f..60d2bcfe13757 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -19,47 +19,48 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{TimeZone, Calendar} +import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ class DateTimeUtilsSuite extends SparkFunSuite { private[this] def getInUTCDays(timestamp: Long): Int = { val tz = TimeZone.getDefault - ((timestamp + tz.getOffset(timestamp)) / DateTimeUtils.MILLIS_PER_DAY).toInt + ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt } test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) - val ns = DateTimeUtils.fromJavaTimestamp(now) + val ns = fromJavaTimestamp(now) assert(ns % 1000000L === 1) - assert(DateTimeUtils.toJavaTimestamp(ns) === now) + assert(toJavaTimestamp(ns) === now) List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => - val ts = DateTimeUtils.toJavaTimestamp(t) - assert(DateTimeUtils.fromJavaTimestamp(ts) === t) - assert(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJavaTimestamp(ts)) === ts) + val ts = toJavaTimestamp(t) + assert(fromJavaTimestamp(ts) === t) + assert(toJavaTimestamp(fromJavaTimestamp(ts)) === ts) } } test("us and julian day") { - val (d, ns) = DateTimeUtils.toJulianDay(0) - assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH) - assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND) - assert(DateTimeUtils.fromJulianDay(d, ns) == 0L) + val (d, ns) = toJulianDay(0) + assert(d === JULIAN_DAY_OF_EPOCH) + assert(ns === SECONDS_PER_DAY / 2 * NANOS_PER_SECOND) + assert(fromJulianDay(d, ns) == 0L) val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100) - val (d1, ns1) = DateTimeUtils.toJulianDay(DateTimeUtils.fromJavaTimestamp(t)) - val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1)) + val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) + val t2 = toJavaTimestamp(fromJulianDay(d1, ns1)) assert(t.equals(t2)) } test("SPARK-6785: java date conversion before and after epoch") { def checkFromToJavaDate(d1: Date): Unit = { - val d2 = DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d1)) + val d2 = toJavaDate(fromJavaDate(d1)) assert(d2.toString === d1.toString) } @@ -95,157 +96,156 @@ class DateTimeUtilsSuite extends SparkFunSuite { } test("string to date") { - import DateTimeUtils.millisToDays var c = Calendar.getInstance() c.set(2015, 0, 28, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-01-28")).get === + assert(stringToDate(UTF8String.fromString("2015-01-28")).get === millisToDays(c.getTimeInMillis)) c.set(2015, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - 
assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015")).get === + assert(stringToDate(UTF8String.fromString("2015")).get === millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03")).get === + assert(stringToDate(UTF8String.fromString("2015-03")).get === millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 ")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18 ")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 123142")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18 123142")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T123123")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18T123123")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18T")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("20150318")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) + assert(stringToDate(UTF8String.fromString("20150318")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) } test("string to timestamp") { var c = Calendar.getInstance() c.set(1969, 11, 31, 16, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === + assert(stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === c.getTimeInMillis * 1000) c.set(2015, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015")).get === + assert(stringToTimestamp(UTF8String.fromString("2015")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === c.getTimeInMillis * 1000) - 
assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-13:53")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17-13:53")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17-01:00")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17+07:03")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18 12:03:17.123")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 456) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.456Z")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18 12:03:17.456Z")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( 
UTF8String.fromString("2015-03-18T12:03:17.123-1:0")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123-01:00")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123121+7:30")).get === c.getTimeInMillis * 1000 + 121) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) @@ -254,7 +254,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("18:12:15")).get === c.getTimeInMillis * 1000) @@ -263,7 +263,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("T18:12:15.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) @@ -272,93 +272,130 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("18:12:15.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) c = Calendar.getInstance() c.set(2011, 4, 6, 7, 8, 9) c.set(Calendar.MILLISECOND, 100) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("238")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp(UTF8String.fromString("238")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) + 
assert(stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) } test("hours") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 13) + assert(getHours(c.getTimeInMillis * 1000) === 13) c.set(2015, 12, 8, 2, 7, 9) - assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 2) + assert(getHours(c.getTimeInMillis * 1000) === 2) } test("minutes") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 2) + assert(getMinutes(c.getTimeInMillis * 1000) === 2) c.set(2015, 2, 8, 2, 7, 9) - assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 7) + assert(getMinutes(c.getTimeInMillis * 1000) === 7) } test("seconds") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 11) + assert(getSeconds(c.getTimeInMillis * 1000) === 11) c.set(2015, 2, 8, 2, 7, 9) - assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 9) + assert(getSeconds(c.getTimeInMillis * 1000) === 9) } test("get day in year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) c.set(2012, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) } test("get year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2015) + assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2015) c.set(2012, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2012) + assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2012) } test("get quarter") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) + assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) c.set(2012, 11, 18, 0, 0, 0) - assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) + assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) } test("get month") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 3) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 3) c.set(2012, 11, 18, 0, 0, 0) - assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 12) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 12) } test("get day of month") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) + assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) c.set(2012, 11, 24, 0, 0, 0) - 
assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + } + + test("date add months") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + val days1 = millisToDays(c1.getTimeInMillis) + val c2 = Calendar.getInstance() + c2.set(2000, 1, 29) + assert(dateAddMonths(days1, 36) === millisToDays(c2.getTimeInMillis)) + c2.set(1996, 0, 31) + assert(dateAddMonths(days1, -13) === millisToDays(c2.getTimeInMillis)) + } + + test("timestamp add months") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + c1.set(Calendar.MILLISECOND, 0) + val ts1 = c1.getTimeInMillis * 1000L + val c2 = Calendar.getInstance() + c2.set(2000, 1, 29, 10, 30, 0) + c2.set(Calendar.MILLISECOND, 123) + val ts2 = c2.getTimeInMillis * 1000L + assert(timestampAddInterval(ts1, 36, 123000) === ts2) + } + + test("monthsBetween") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + val c2 = Calendar.getInstance() + c2.set(1996, 9, 30, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 114ab91d10aa0..3ea0f9ed3bddd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -40,8 +40,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} -import org.apache.spark.sql.execution.datasources.CreateTableUsingAsSelect -import org.apache.spark.sql.json.JacksonGenerator +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} +import org.apache.spark.sql.json.{JacksonGenerator, JSONRelation} +import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -1546,6 +1547,21 @@ class DataFrame private[sql]( } } + /** + * Returns a best-effort snapshot of the files that compose this DataFrame. This method simply + * asks each constituent BaseRelation for its respective files and takes the union of all results. + * Depending on the source relations, this may not find all input files. Duplicates are removed. 
+ */ + def inputFiles: Array[String] = { + val files: Seq[String] = logicalPlan.collect { + case LogicalRelation(fsBasedRelation: HadoopFsRelation) => + fsBasedRelation.paths.toSeq + case LogicalRelation(jsonRelation: JSONRelation) => + jsonRelation.path.toSeq + }.flatten + files.toSet.toArray + } + //////////////////////////////////////////////////////////////////////////// // for Python API //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 4ec58082e7aef..2e68e358f2f1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.{util => ju, lang => jl} + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ @@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { + require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), + s"Fractions must be in [0, 1], but got $fractions.") + import org.apache.spark.sql.functions.{rand, udf} + val c = Column(col) + val r = rand(seed) + val f = udf { (stratum: Any, x: Double) => + x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0) + } + df.filter(f(c, r)) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. 
+ * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 40eba33f595ca..6644e85d4a037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -229,7 +229,7 @@ private[spark] object SQLConf { " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, use the new optimized Tungsten physical execution backend.") val DIALECT = stringConf( @@ -247,6 +247,13 @@ private[spark] object SQLConf { "otherwise the schema is picked from the summary file or a random data file " + "if no summary file is available.") + val PARQUET_SCHEMA_RESPECT_SUMMARIES = booleanConf("spark.sql.parquet.respectSummaryFiles", + defaultValue = Some(false), + doc = "When true, we assume that all part-files of Parquet are consistent with " + + "summary files and will ignore them when merging the schema. Otherwise, if this is " + + "false (the default), we will merge all part-files. This should be considered " + + "an expert-only option, and shouldn't be enabled before knowing exactly what it means.") + val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", defaultValue = Some(false), doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + @@ -322,7 +329,7 @@ private[spark] object SQLConf { " memory.") val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") // This is only used for the thriftserver diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 454b7b91a63f5..1620fc401ba6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -114,7 +114,7 @@ private[sql] class FixedDecimalColumnBuilder( precision: Int, scale: Int) extends NativeColumnBuilder( - new FixedDecimalColumnStats, + new FixedDecimalColumnStats(precision, scale), FIXED_DECIMAL(precision, scale)) // TODO (lian) Add support for array, struct and map diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 32a84b2676e07..af1a8ecca9b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -234,14 +234,14 @@ private[sql] class BinaryColumnStats extends ColumnStats { InternalRow(null, null, nullCount, count, sizeInBytes) } -private[sql] class FixedDecimalColumnStats extends ColumnStats { +private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { protected var upper: Decimal = null protected var lower: Decimal = null override def gatherStats(row: InternalRow, ordinal: Int): 
Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getDecimal(ordinal) + val value = row.getDecimal(ordinal, precision, scale) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += FIXED_DECIMAL.defaultSize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 2863f6c230a9d..30f8fe320db3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -392,7 +392,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) } override def getField(row: InternalRow, ordinal: Int): Decimal = { - row.getDecimal(ordinal) + row.getDecimal(ordinal, precision, scale) } override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 41a0c519ba527..6bd57f010a990 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -47,7 +47,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = { + // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to + // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to + // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. + !newPartitioning.isInstanceOf[RangePartitioning] + } /** * Determines whether records must be defensively copied before being sent to the shuffle. @@ -197,41 +202,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // True iff every child's outputPartitioning satisfies the corresponding - // required data distribution. - def meetsRequirements: Boolean = - operator.requiredChildDistribution.zip(operator.children).forall { - case (required, child) => - val valid = child.outputPartitioning.satisfies(required) - logDebug( - s"${if (valid) "Valid" else "Invalid"} distribution," + - s"required: $required current: ${child.outputPartitioning}") - valid - } - - // True iff any of the children are incorrectly sorted. - def needsAnySort: Boolean = - operator.requiredChildOrdering.zip(operator.children).exists { - case (required, child) => required.nonEmpty && required != child.outputOrdering - } - - // True iff outputPartitionings of children are compatible with each other. - // It is possible that every child satisfies its required data distribution - // but two children have incompatible outputPartitionings. For example, - // A dataset is range partitioned by "a.asc" (RangePartitioning) and another - // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two - // datasets are both clustered by "a", but these two outputPartitionings are not - // compatible. - // TODO: ASSUMES TRANSITIVITY? 
- def compatible: Boolean = - operator.children - .map(_.outputPartitioning) - .sliding(2) - .forall { - case Seq(a) => true - case Seq(a, b) => a.compatibleWith(b) - } - // Adds Exchange or Sort operators as required def addOperatorsIfNecessary( partitioning: Partitioning, @@ -264,33 +234,26 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ addSortIfNecessary(addShuffleIfNecessary(child)) } - if (meetsRequirements && compatible && !needsAnySort) { - operator - } else { - // At least one child does not satisfies its required data distribution or - // at least one child's outputPartitioning is not compatible with another child's - // outputPartitioning. In this case, we need to add Exchange operators. - val requirements = - (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) - - val fixedChildren = requirements.zipped.map { - case (AllTuples, rowOrdering, child) => - addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) - case (OrderedDistribution(ordering), rowOrdering, child) => - addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) - case (UnspecifiedDistribution, Seq(), child) => - child - case (UnspecifiedDistribution, rowOrdering, child) => - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) - case (dist, ordering, _) => - sys.error(s"Don't know how to ensure $dist with ordering $ordering") - } + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) - operator.withNewChildren(fixedChildren) + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } + + operator.withNewChildren(fixedChildren) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 5ad4691a5ca07..d851eae3fcc71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.TaskContext +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -108,7 +108,7 @@ case class GeneratedAggregate( Add( Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType) - ) :: currentSum :: zero :: Nil) + ) :: currentSum :: Nil) val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -118,45 +118,6 @@ case class GeneratedAggregate( 
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case cs @ CombineSum(expr) => - val calcType = - expr.dataType match { - case DecimalType.Fixed(p, s) => - DecimalType.bounded(p + 10, s) - case _ => - expr.dataType - } - - val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal.create(null, calcType) - - // Coalesce avoids double calculation... - // but really, common sub expression elimination would be better.... - val zero = Cast(Literal(0), calcType) - // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its - // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val actualExpr = expr match { - case UnscaledValue(e) => e - case _ => expr - } - // partial sum result can be null only when no input rows present - val updateFunction = If( - IsNotNull(actualExpr), - Coalesce( - Add( - Coalesce(currentSum :: zero :: Nil), - Cast(expr, calcType)) :: currentSum :: zero :: Nil), - currentSum) - - val result = - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(currentSum, cs.dataType) - case _ => currentSum - } - - AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) @@ -241,7 +202,7 @@ case class GeneratedAggregate( val schemaSupportsUnsafe: Boolean = { UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + UnsafeProjection.canSupport(groupKeySchema) } child.execute().mapPartitions { iter => @@ -299,12 +260,14 @@ case class GeneratedAggregate( } else if (unsafeEnabled && schemaSupportsUnsafe) { assert(iter.hasNext, "There should be at least one row for this path") log.info("Using Unsafe-based aggregator") + val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") val aggregationMap = new UnsafeFixedWidthAggregationMap( newAggregationBuffer(EmptyRow), aggregationBufferSchema, groupKeySchema, TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity + pageSizeBytes, false // disable tracking of performance metrics ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2dee3542d6101..a2145b185ce90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -37,61 +35,15 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { - case StringType => PrefixComparators.STRING - case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType => PrefixComparators.FLOAT - case DoubleType => PrefixComparators.DOUBLE + case StringType if sortOrder.isAscending => PrefixComparators.STRING + case StringType if 
!sortOrder.isAscending => PrefixComparators.STRING_DESC + case BooleanType | ByteType | ShortType | IntegerType | LongType if sortOrder.isAscending => + PrefixComparators.LONG + case BooleanType | ByteType | ShortType | IntegerType | LongType if !sortOrder.isAscending => + PrefixComparators.LONG_DESC + case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE + case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC case _ => NoOpPrefixComparator } } - - def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { - sortOrder.dataType match { - case StringType => (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) - } - case BooleanType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 - else 0 - } - case ByteType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Byte] - } - case ShortType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Short] - } - case IntegerType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Int] - } - case LongType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Long] - } - case FloatType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX - else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) - } - case DoubleType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX - else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) - } - case _ => (row: InternalRow) => 0L - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index c808442a4849b..e5bbd0aaed0a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -298,7 +298,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.getDecimal(i) + val value = row.getDecimal(i, decimal.precision, decimal.scale) val javaBigDecimal = value.toJavaBigDecimal // First, write out the unscaled value. 
val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 306bbfec624c0..03d24a88d4ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -193,15 +193,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = { - aggregate.Utils.tryConvert( - plan, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled).isDefined + def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match { + case a: logical.Aggregate => + if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) { + a.newAggregation.isDefined + } else { + Utils.checkInvalidAggregateFunction2(a) + false + } + case _ => false } def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { - case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true + case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && Seq(IntegerType, LongType).contains(exprs.head.dataType) => true @@ -217,12 +221,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case p: logical.Aggregate => - val converted = - aggregate.Utils.tryConvert( - p, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled) + case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 && + sqlContext.conf.codegenEnabled => + val converted = p.newAggregation converted match { case None => Nil // Cannot convert to new aggregation code path. case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => @@ -339,8 +340,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * if necessary. 
*/ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) { - execution.UnsafeExternalSort(sortExprs, global, child) + if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && + TungstenSort.supportsSchema(child.schema)) { + execution.TungstenSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { @@ -363,23 +365,27 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Sort(sortExprs, global, child) => getSortOperator(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => - execution.Project(projectList, planLater(child)) :: Nil + // If unsafe mode is enabled and we support these data types in Unsafe, use the + // Tungsten project. Otherwise, use the normal project. + if (sqlContext.conf.unsafeEnabled && + UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { + execution.TungstenProject(projectList, planLater(child)) :: Nil + } else { + execution.Project(projectList, planLater(child)) :: Nil + } case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil case a @ logical.Aggregate(group, agg, child) => { - val useNewAggregation = - aggregate.Utils.tryConvert( - a, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled).isDefined - if (useNewAggregation) { + val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled + if (useNewAggregation && a.newAggregation.isDefined) { // If this logical.Aggregate can be planned to use new aggregation code path // (i.e. it can be planned by the Strategy Aggregation), we will not use the old // aggregation code path. 
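Earlier in this hunk, `getSortOperator` now picks the Tungsten sort only when unsafe mode, code generation, and schema support all line up, falling back to the external or on-heap sort otherwise. A small stand-alone sketch of that three-way fallback; the names below are illustrative placeholders, not the actual operators:

```scala
object SortChoiceDemo extends App {
  sealed trait SortChoice
  case object TungstenSortChoice extends SortChoice   // binary processing, needs codegen
  case object ExternalSortChoice extends SortChoice   // row-based, spills to disk
  case object InMemorySortChoice extends SortChoice   // legacy on-heap sort

  def chooseSort(
      unsafeEnabled: Boolean,
      codegenEnabled: Boolean,
      externalSortEnabled: Boolean,
      schemaSupported: Boolean): SortChoice =
    if (unsafeEnabled && codegenEnabled && schemaSupported) TungstenSortChoice
    else if (externalSortEnabled) ExternalSortChoice
    else InMemorySortChoice

  // With codegen disabled, the planner falls back to the external sort even when unsafe
  // mode is on and the schema is supported, which is the behavioural change in this hunk.
  assert(chooseSort(unsafeEnabled = true, codegenEnabled = false,
    externalSortEnabled = true, schemaSupported = true) == ExternalSortChoice)
}
```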
Nil } else { + Utils.checkInvalidAggregateFunction2(a) execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala index 0c9082897f390..98538c462bc89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala @@ -72,8 +72,10 @@ case class Aggregate2Sort( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => if (aggregateExpressions.length == 0) { - new GroupingIterator( + new FinalSortAggregationIterator( groupingExpressions, + Nil, + Nil, resultExpressions, newMutableProjection, child.output, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala index 1b89edafa8dad..2ca0cb82c1aab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -41,7 +41,8 @@ private[sql] abstract class SortAggregationIterator( /////////////////////////////////////////////////////////////////////////// protected val aggregateFunctions: Array[AggregateFunction2] = { - var bufferOffset = initialBufferOffset + var mutableBufferOffset = 0 + var inputBufferOffset: Int = initialInputBufferOffset val functions = new Array[AggregateFunction2](aggregateExpressions.length) var i = 0 while (i < aggregateExpressions.length) { @@ -54,13 +55,18 @@ private[sql] abstract class SortAggregationIterator( // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. BindReferences.bindReference(func, inputAttributes) - case _ => func + case _ => + // We only need to set inputBufferOffset for aggregate functions with mode + // PartialMerge and Final. + func.inputBufferOffset = inputBufferOffset + inputBufferOffset += func.bufferSchema.length + func } - // Set bufferOffset for this function. It is important that setting bufferOffset - // happens after all potential bindReference operations because bindReference - // will create a new instance of the function. - funcWithBoundReferences.bufferOffset = bufferOffset - bufferOffset += funcWithBoundReferences.bufferSchema.length + // Set mutableBufferOffset for this function. It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. + funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset + mutableBufferOffset += funcWithBoundReferences.bufferSchema.length functions(i) = funcWithBoundReferences i += 1 } @@ -97,25 +103,24 @@ private[sql] abstract class SortAggregationIterator( // The number of elements of the underlying buffer of this operator. // All aggregate functions are sharing this underlying buffer and they find their // buffer values through bufferOffset. 
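The comment above describes the single buffer that every aggregate function writes into, and the hunks in this file split the old `bufferOffset` into a `mutableBufferOffset` (into the operator's own buffer, which no longer reserves placeholder slots) and an `inputBufferOffset` (into an upstream partial-aggregation row, which still carries the grouping columns first). A toy model of that bookkeeping with hypothetical names, not Spark's classes:

```scala
// A toy model of the offset bookkeeping; names are hypothetical, not Spark's classes.
object BufferOffsetDemo extends App {
  final class ToyAggregate(val bufferWidth: Int) {
    var mutableBufferOffset: Int = 0  // slot where this function writes in the shared buffer
    var inputBufferOffset: Int = 0    // slot where its values start in an upstream partial row
  }

  val functions = Seq(new ToyAggregate(bufferWidth = 1), new ToyAggregate(bufferWidth = 2))

  // The operator's own buffer holds nothing but aggregation state, so offsets start at 0
  // (no placeholder slots for grouping expressions any more).
  var mutableOffset = 0
  // An upstream partial-aggregation row is laid out as
  // |groupingExpr1|...|groupingExprN|buffer1|...|bufferM|, so reads start after the N keys.
  val initialInputBufferOffset = 3  // e.g. three grouping expressions
  var inputOffset = initialInputBufferOffset

  functions.foreach { f =>
    f.mutableBufferOffset = mutableOffset
    mutableOffset += f.bufferWidth
    f.inputBufferOffset = inputOffset
    inputOffset += f.bufferWidth
  }

  // The shared buffer is now exactly as wide as the aggregation state it holds.
  assert(functions.map(_.bufferWidth).sum == 3)
  assert(functions.map(f => (f.mutableBufferOffset, f.inputBufferOffset)) == Seq((0, 3), (1, 4)))
}
```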
- var size = initialBufferOffset - var i = 0 - while (i < aggregateFunctions.length) { - size += aggregateFunctions(i).bufferSchema.length - i += 1 - } - new GenericMutableRow(size) + // var size = 0 + // var i = 0 + // while (i < aggregateFunctions.length) { + // size += aggregateFunctions(i).bufferSchema.length + // i += 1 + // } + new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum) } protected val joinedRow = new JoinedRow - protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp) - // This projection is used to initialize buffer values for all AlgebraicAggregates. protected val algebraicInitialProjection = { - val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + val initExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.initialValues case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } + newMutableProjection(initExpressions, Nil)().target(buffer) } @@ -132,10 +137,6 @@ private[sql] abstract class SortAggregationIterator( // Indicates if we has new group of rows to process. protected var hasNewGroup: Boolean = true - /////////////////////////////////////////////////////////////////////////// - // Private methods - /////////////////////////////////////////////////////////////////////////// - /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(): Unit = { algebraicInitialProjection(EmptyRow) @@ -160,6 +161,10 @@ private[sql] abstract class SortAggregationIterator( } } + /////////////////////////////////////////////////////////////////////////// + // Private methods + /////////////////////////////////////////////////////////////////////////// + /** Processes rows in the current group. It will stop when it find a new group. */ private def processCurrentGroup(): Unit = { currentGroupingKey = nextGroupingKey @@ -218,10 +223,13 @@ private[sql] abstract class SortAggregationIterator( // Methods that need to be implemented /////////////////////////////////////////////////////////////////////////// - protected def initialBufferOffset: Int + /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */ + protected def initialInputBufferOffset: Int + /** The function used to process an input row. */ protected def processRow(row: InternalRow): Unit + /** The function used to generate the result row. */ protected def generateOutput(): InternalRow /////////////////////////////////////////////////////////////////////////// @@ -231,37 +239,6 @@ private[sql] abstract class SortAggregationIterator( initialize() } -/** - * An iterator only used to group input rows according to values of `groupingExpressions`. - * It assumes that input rows are already grouped by values of `groupingExpressions`. - */ -class GroupingIterator( - groupingExpressions: Seq[NamedExpression], - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - Nil, - newMutableProjection, - inputAttributes, - inputIter) { - - private val resultProjection = - newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))() - - override protected def initialBufferOffset: Int = 0 - - override protected def processRow(row: InternalRow): Unit = { - // Since we only do grouping, there is nothing to do at here. 
- } - - override protected def generateOutput(): InternalRow = { - resultProjection(currentGroupingKey) - } -} - /** * An iterator used to do partial aggregations (for those aggregate functions with mode Partial). * It assumes that input rows are already grouped by values of `groupingExpressions`. @@ -291,7 +268,7 @@ class PartialSortAggregationIterator( newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) } - override protected def initialBufferOffset: Int = 0 + override protected def initialInputBufferOffset: Int = 0 override protected def processRow(row: InternalRow): Unit = { // Process all algebraic aggregate functions. @@ -318,11 +295,7 @@ class PartialSortAggregationIterator( * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| * * The format of its internal buffer is: - * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN| - * Every placeholder is for a grouping expression. - * The actual buffers are stored after placeholderN. - * The reason that we have placeholders at here is to make our underlying buffer have the same - * length with a input row. + * |aggregationBuffer1|...|aggregationBufferN| * * The format of its output rows is: * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| @@ -340,33 +313,21 @@ class PartialMergeSortAggregationIterator( inputAttributes, inputIter) { - private val placeholderAttributes = - Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) - // This projection is used to merge buffer values for all AlgebraicAggregates. private val algebraicMergeProjection = { - val bufferSchemata = - placeholderAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ - placeholderAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + val mergeInputSchema = + aggregateFunctions.flatMap(_.bufferAttributes) ++ + groupingExpressions.map(_.toAttribute) ++ + aggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } - newMutableProjection(mergeExpressions, bufferSchemata)() + newMutableProjection(mergeExpressions, mergeInputSchema)() } - // This projection is used to extract aggregation buffers from the underlying buffer. - // We need it because the underlying buffer has placeholders at its beginning. - private val extractsBufferValues = { - val expressions = aggregateFunctions.flatMap { - case agg => agg.bufferAttributes - } - - newMutableProjection(expressions, inputAttributes)() - } - - override protected def initialBufferOffset: Int = groupingExpressions.length + override protected def initialInputBufferOffset: Int = groupingExpressions.length override protected def processRow(row: InternalRow): Unit = { // Process all algebraic aggregate functions. @@ -381,7 +342,7 @@ class PartialMergeSortAggregationIterator( override protected def generateOutput(): InternalRow = { // We output grouping expressions and aggregation buffers. 
- joinedRow(currentGroupingKey, extractsBufferValues(buffer)) + joinedRow(currentGroupingKey, buffer).copy() } } @@ -393,11 +354,7 @@ class PartialMergeSortAggregationIterator( * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| * * The format of its internal buffer is: - * |placeholder1|...|placeholder N|aggregationBuffer1|...|aggregationBufferN| - * Every placeholder is for a grouping expression. - * The actual buffers are stored after placeholderN. - * The reason that we have placeholders at here is to make our underlying buffer have the same - * length with a input row. + * |aggregationBuffer1|...|aggregationBufferN| * * The format of its output rows is represented by the schema of `resultExpressions`. */ @@ -425,27 +382,23 @@ class FinalSortAggregationIterator( newMutableProjection( resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)() - private val offsetAttributes = - Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) - // This projection is used to merge buffer values for all AlgebraicAggregates. private val algebraicMergeProjection = { - val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ - offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + val mergeInputSchema = + aggregateFunctions.flatMap(_.bufferAttributes) ++ + groupingExpressions.map(_.toAttribute) ++ + aggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } - newMutableProjection(mergeExpressions, bufferSchemata)() + newMutableProjection(mergeExpressions, mergeInputSchema)() } // This projection is used to evaluate all AlgebraicAggregates. private val algebraicEvalProjection = { - val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ - offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) + val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp @@ -454,7 +407,7 @@ class FinalSortAggregationIterator( newMutableProjection(evalExpressions, bufferSchemata)() } - override protected def initialBufferOffset: Int = groupingExpressions.length + override protected def initialInputBufferOffset: Int = groupingExpressions.length override def initialize(): Unit = { if (inputIter.hasNext) { @@ -471,7 +424,10 @@ class FinalSortAggregationIterator( // Right now, the buffer only contains initial buffer values. Because // merging two buffers with initial values will generate a row that // still store initial values. We set the currentRow as the copy of the current buffer. - val currentRow = buffer.copy() + // Because input aggregation buffer has initialInputBufferOffset extra values at the + // beginning, we create a dummy row for this part. + val currentRow = + joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy() nextGroupingKey = groupGenerator(currentRow).copy() firstRowInNextGroup = currentRow } else { @@ -518,18 +474,15 @@ class FinalSortAggregationIterator( * Final mode. 
* * The format of its internal buffer is: - * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)| - * The first N placeholders represent slots of grouping expressions. - * Then, next M placeholders represent slots of col1 to colM. + * |aggregationBuffer1|...|aggregationBuffer(N+M)| * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode - * Complete. The reason that we have placeholders at here is to make our underlying buffer - * have the same length with a input row. + * Complete. * * The format of its output rows is represented by the schema of `resultExpressions`. */ class FinalAndCompleteSortAggregationIterator( - override protected val initialBufferOffset: Int, + override protected val initialInputBufferOffset: Int, groupingExpressions: Seq[NamedExpression], finalAggregateExpressions: Seq[AggregateExpression2], finalAggregateAttributes: Seq[Attribute], @@ -561,9 +514,6 @@ class FinalAndCompleteSortAggregationIterator( newMutableProjection(resultExpressions, inputSchema)() } - private val offsetAttributes = - Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) - // All aggregate functions with mode Final. private val finalAggregateFunctions: Array[AggregateFunction2] = { val functions = new Array[AggregateFunction2](finalAggregateExpressions.length) @@ -601,38 +551,38 @@ class FinalAndCompleteSortAggregationIterator( // This projection is used to merge buffer values for all AlgebraicAggregates with mode // Final. private val finalAlgebraicMergeProjection = { - val numCompleteOffsetAttributes = - completeAggregateFunctions.map(_.bufferAttributes.length).sum - val completeOffsetAttributes = - Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)()) - val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp) - - val bufferSchemata = - offsetAttributes ++ finalAggregateFunctions.flatMap(_.bufferAttributes) ++ - completeOffsetAttributes ++ offsetAttributes ++ - finalAggregateFunctions.flatMap(_.cloneBufferAttributes) ++ completeOffsetAttributes + // The first initialInputBufferOffset values of the input aggregation buffer is + // for grouping expressions and distinct columns. + val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset) + + val completeOffsetExpressions = + Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + + val mergeInputSchema = + finalAggregateFunctions.flatMap(_.bufferAttributes) ++ + completeAggregateFunctions.flatMap(_.bufferAttributes) ++ + groupingAttributesAndDistinctColumns ++ + finalAggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = - placeholderExpressions ++ finalAggregateFunctions.flatMap { + finalAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } ++ completeOffsetExpressions - - newMutableProjection(mergeExpressions, bufferSchemata)() + newMutableProjection(mergeExpressions, mergeInputSchema)() } // This projection is used to update buffer values for all AlgebraicAggregates with mode // Complete. 
private val completeAlgebraicUpdateProjection = { - val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum - val finalOffsetAttributes = - Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)()) - val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp) + // We do not touch buffer values of aggregate functions with the Final mode. + val finalOffsetExpressions = + Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) val bufferSchema = - offsetAttributes ++ finalOffsetAttributes ++ + finalAggregateFunctions.flatMap(_.bufferAttributes) ++ completeAggregateFunctions.flatMap(_.bufferAttributes) val updateExpressions = - placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.updateExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } @@ -641,9 +591,7 @@ class FinalAndCompleteSortAggregationIterator( // This projection is used to evaluate all AlgebraicAggregates. private val algebraicEvalProjection = { - val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ - offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) + val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp @@ -667,7 +615,10 @@ class FinalAndCompleteSortAggregationIterator( // Right now, the buffer only contains initial buffer values. Because // merging two buffers with initial values will generate a row that // still store initial values. We set the currentRow as the copy of the current buffer. - val currentRow = buffer.copy() + // Because input aggregation buffer has initialInputBufferOffset extra values at the + // beginning, we create a dummy row for this part. 
+ val currentRow = + joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy() nextGroupingKey = groupGenerator(currentRow).copy() firstRowInNextGroup = currentRow } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 073c45ae2f9f2..cc54319171bdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -184,7 +184,7 @@ private[sql] case class ScalaUDAF( bufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, - bufferOffset, + inputBufferOffset, null) lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = @@ -192,9 +192,16 @@ private[sql] case class ScalaUDAF( bufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, - bufferOffset, + mutableBufferOffset, null) + lazy val evalAggregateBuffer: InputAggregationBuffer = + new InputAggregationBuffer( + bufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableBufferOffset, + null) override def initialize(buffer: MutableRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer @@ -217,10 +224,10 @@ private[sql] case class ScalaUDAF( udaf.merge(mutableAggregateBuffer, inputAggregateBuffer) } - override def eval(buffer: InternalRow = null): Any = { - inputAggregateBuffer.underlyingInputBuffer = buffer + override def eval(buffer: InternalRow): Any = { + evalAggregateBuffer.underlyingInputBuffer = buffer - udaf.evaluate(inputAggregateBuffer) + udaf.evaluate(evalAggregateBuffer) } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 5bbe6c162ff4b..03635baae4a5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -29,150 +29,6 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType} * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { - // Right now, we do not support complex types in the grouping key schema. - private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { - val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { - case array: ArrayType => true - case map: MapType => true - case struct: StructType => true - case _ => false - } - - !hasComplexTypes - } - - private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - // We do not support multiple COUNT DISTINCT columns for now. 
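The conversion being removed here (the planner hunks now consult `a.newAggregation` and `Utils.checkInvalidAggregateFunction2` instead) only handles a single DISTINCT column set, which is why the comment above restricts the `CountDistinct` case that follows to one child and why multiple distinct column sets abort the conversion further down. A stand-alone illustration of that restriction using made-up mini-types:

```scala
// Made-up mini-types; only the "at most one distinct column set" rule is modeled here.
object DistinctColumnSetDemo extends App {
  final case class Agg(name: String, children: Seq[String], isDistinct: Boolean)

  def hasMultipleDistinctColumnSets(aggs: Seq[Agg]): Boolean =
    aggs.filter(_.isDistinct).map(_.children).distinct.length > 1

  // COUNT(DISTINCT a) and SUM(DISTINCT a) share one column set, so conversion can proceed.
  val sameSet = Seq(Agg("count", Seq("a"), isDistinct = true), Agg("sum", Seq("a"), isDistinct = true))
  // COUNT(DISTINCT a) and COUNT(DISTINCT b) use two different sets, so it cannot.
  val twoSets = Seq(Agg("count", Seq("a"), isDistinct = true), Agg("count", Seq("b"), isDistinct = true))

  assert(!hasMultipleDistinctColumnSets(sameSet))
  assert(hasMultipleDistinctColumnSets(twoSets))
}
```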
- case expressions.CountDistinct(children) if children.length == 1 => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - } - // Check if there is any expressions.AggregateExpression1 left. - // If so, we cannot convert this plan. - val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => - // For every expressions, check if it contains AggregateExpression1. - expr.find { - case agg: expressions.AggregateExpression1 => true - case other => false - }.isDefined - } - - // Check if there are multiple distinct columns. - val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) - val hasMultipleDistinctColumnSets = - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - true - } else { - false - } - - if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None - - case other => None - } - - private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { - // If the plan cannot be converted, we will do a final round check to if the original - // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, - // we need to throw an exception. - val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg.aggregateFunction - } - }.distinct - if (aggregateFunction2s.nonEmpty) { - // For functions implemented based on the new interface, prepare a list of function names. - val invalidFunctions = { - if (aggregateFunction2s.length > 1) { - s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + - s"and ${aggregateFunction2s.head.nodeName} are" - } else { - s"${aggregateFunction2s.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} implemented based on the new Aggregate Function " + - s"interface and it cannot be used with functions implemented based on " + - s"the old Aggregate Function interface." 
- throw new AnalysisException(errorMessage) - } - } - - def tryConvert( - plan: LogicalPlan, - useNewAggregation: Boolean, - codeGenEnabled: Boolean): Option[Aggregate] = plan match { - case p: Aggregate if useNewAggregation && codeGenEnabled => - val converted = tryConvert(p) - if (converted.isDefined) { - converted - } else { - checkInvalidAggregateFunction2(p) - None - } - case p: Aggregate => - checkInvalidAggregateFunction2(p) - None - case other => None - } - def planAggregateWithoutDistinct( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[AggregateExpression2], @@ -292,8 +148,8 @@ object Utils { AggregateExpression2(aggregateFunction, PartialMerge, false) } val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + partialMergeAggregateExpressions.flatMap { agg => + agg.aggregateFunction.bufferAttributes } val partialMergeAggregate = Aggregate2Sort( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fe429d862a0a3..2294a670c735f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -49,6 +49,31 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends override def outputOrdering: Seq[SortOrder] = child.outputOrdering } + +/** + * A variant of [[Project]] that returns [[UnsafeRow]]s. + */ +case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + this.transformAllExpressions { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + } + val project = UnsafeProjection.create(projectList, child.output) + iter.map(project) + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + + /** * :: DeveloperApi :: */ @@ -195,137 +220,6 @@ case class TakeOrderedAndProject( override def outputOrdering: Seq[SortOrder] = sortOrder } -/** - * :: DeveloperApi :: - * Performs a sort on-heap. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -@DeveloperApi -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - iterator.map(_.copy()).toArray.sorted(ordering).iterator - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * :: DeveloperApi :: - * Performs a sort, spilling to disk as needed. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. 
- */ -@DeveloperApi -case class ExternalSort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r.copy, null))) - val baseIterator = sorter.iterator.map(_._1) - // TODO(marmbrus): The complex type signature below thwarts inference for no reason. - CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * :: DeveloperApi :: - * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of - * Project Tungsten). - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will - * spill every `frequency` records. - */ -@DeveloperApi -case class UnsafeExternalSort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan, - testSpillFrequency: Int = 0) - extends UnaryNode { - - private[this] val schema: StructType = child.schema - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") - def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { - val ordering = newOrdering(sortOrder, child.output) - val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) - // Hack until we generate separate comparator implementations for ascending vs. descending - // (or choose to codegen them): - val prefixComparator = { - val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) - if (sortOrder.head.direction == Descending) { - new PrefixComparator { - override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) - } - } else { - comp - } - } - val prefixComputer = { - val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) - new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = prefixComputer(row) - } - } - val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) - if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) - } - sorter.sort(iterator) - } - child.execute().mapPartitions(doSort, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def outputsUnsafeRows: Boolean = true -} - -@DeveloperApi -object UnsafeExternalSort { - /** - * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
- */ - def supportsSchema(schema: StructType): Boolean = { - UnsafeExternalRowSorter.supportsSchema(schema) - } -} - /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e73b3704d4dfe..0cdb407ad57b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -308,7 +308,7 @@ private[sql] object ResolvedDataSource { mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[IntervalType])) { + if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } val clazz: Class[_] = lookupDataSource(provider) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index aeeb0e45270dd..f26f41fb75d57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -158,8 +158,8 @@ package object debug { case (row: InternalRow, StructType(fields)) => row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } - case (s: Seq[_], ArrayType(elemType, _)) => - s.foreach(typeCheck(_, elemType)) + case (a: ArrayData, ArrayType(elemType, _)) => + a.toArray().foreach(typeCheck(_, elemType)) case (m: Map[_, _], MapType(keyType, valueType, _)) => m.keys.foreach(typeCheck(_, keyType)) m.values.foreach(typeCheck(_, valueType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala deleted file mode 100644 index 568b7ac2c5987..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -/** - * Package containing expressions that are specific to Spark runtime. 
- */ -package object expressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index abaa4a6ce86a2..624efc1b1d734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -62,7 +62,7 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = buildHashRelation(input.iterator) + val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index c9d1a880f4ef4..77e7fe71009b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = buildHashRelation(input.iterator) + val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index f71c0ce352904..a60593911f94f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -37,17 +37,17 @@ case class BroadcastLeftSemiJoinHash( condition: Option[Expression]) extends BinaryNode with HashSemiJoin { protected override def doExecute(): RDD[InternalRow] = { - val buildIter = right.execute().map(_.copy()).collect().toIterator + val input = right.execute().map(_.copy()).collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) + val hashSet = buildKeyHashSet(input.toIterator) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) } } else { - val hashRelation = buildHashRelation(buildIter) + val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 700636966f8be..83b726a8e2897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -47,13 +47,11 
@@ case class BroadcastNestedLoopJoin( override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true - @transient private[this] lazy val resultProjection: Projection = { + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { if (outputsUnsafeRows) { UnsafeProjection.create(schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -96,7 +94,6 @@ case class BroadcastNestedLoopJoin( var streamRowMatched = false while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => @@ -135,17 +132,26 @@ case class BroadcastNestedLoopJoin( val buf: CompactBuffer[InternalRow] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => - buf += resultProjection(new JoinedRow(leftNulls, rel(i))) - case (LeftOuter | FullOuter, BuildLeft) => - buf += resultProjection(new JoinedRow(rel(i), rightNulls)) - case _ => + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => + val joinedRow = new JoinedRow + joinedRow.withLeft(leftNulls) + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + buf += resultProjection(joinedRow.withRight(rel(i))).copy() + } + i += 1 } - } - i += 1 + case (LeftOuter | FullOuter, BuildLeft) => + val joinedRow = new JoinedRow + joinedRow.withRight(rightNulls) + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + buf += resultProjection(joinedRow.withLeft(rel(i))).copy() + } + i += 1 + } + case _ => } buf.toSeq } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 46ab5b0d1cc6d..6b3d1652923fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.util.collection.CompactBuffer trait HashJoin { @@ -44,16 +43,24 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - protected[this] def supportUnsafe: Boolean = { + protected[this] def isUnsafeMode: Boolean = { (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + + @transient protected lazy val buildSideKeyGenerator: Projection = + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildPlan.output) + } else { + newMutableProjection(buildKeys, buildPlan.output)() + } @transient protected lazy val streamSideKeyGenerator: Projection = - if (supportUnsafe) { + if (isUnsafeMode) { 
UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newMutableProjection(streamedKeys, streamedPlan.output)() @@ -65,18 +72,16 @@ trait HashJoin { { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ - private[this] var currentHashMatches: CompactBuffer[InternalRow] = _ + private[this] var currentHashMatches: Seq[InternalRow] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] val resultProjection: Projection = { - if (supportUnsafe) { + private[this] val resultProjection: (InternalRow) => InternalRow = { + if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -122,12 +127,4 @@ trait HashJoin { } } } - - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, buildKeys, buildPlan) - } else { - HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 6bf2f82954046..7e671e7914f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -75,30 +75,36 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - protected[this] def supportUnsafe: Boolean = { + protected[this] def isUnsafeMode: Boolean = { (self.codegenEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode - protected[this] def streamedKeyGenerator(): Projection = { - if (supportUnsafe) { + @transient protected lazy val buildKeyGenerator: Projection = + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildPlan.output) + } else { + newMutableProjection(buildKeys, buildPlan.output)() + } + + @transient protected[this] lazy val streamedKeyGenerator: Projection = { + if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newProjection(streamedKeys, streamedPlan.output) } } - @transient private[this] lazy val resultProjection: Projection = { - if (supportUnsafe) { + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { + if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -230,12 +236,4 @@ trait HashOuterJoin { hashTable } - - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, buildKeys, buildPlan) - } else { - HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 7f49264d40354..97fde8f975bfd 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -35,11 +35,13 @@ trait HashSemiJoin { protected[this] def supportUnsafe: Boolean = { (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(left.schema)) + && UnsafeProjection.canSupport(left.schema) + && UnsafeProjection.canSupport(right.schema)) } - override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows + override def outputsUnsafeRows: Boolean = supportUnsafe override def canProcessUnsafeRows: Boolean = supportUnsafe + override def canProcessSafeRows: Boolean = !supportUnsafe @transient protected lazy val leftKeyGenerator: Projection = if (supportUnsafe) { @@ -87,14 +89,6 @@ trait HashSemiJoin { }) } - protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, rightKeys, right) - } else { - HashedRelation(buildIter, newProjection(rightKeys, right.output)) - } - } - protected def hashSemiJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation): Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 8d5731afd59b8..f88a45f48aee9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,12 +18,16 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.map.BytesToBytesMap +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.collection.CompactBuffer @@ -32,7 +36,7 @@ import org.apache.spark.util.collection.CompactBuffer * object. 
*/ private[joins] sealed trait HashedRelation { - def get(key: InternalRow): CompactBuffer[InternalRow] + def get(key: InternalRow): Seq[InternalRow] // This is a helper method to implement Externalizable, and is used by // GeneralHashedRelation and UniqueKeyHashedRelation @@ -59,9 +63,9 @@ private[joins] final class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { - def this() = this(null) // Needed for serialization + private def this() = this(null) // Needed for serialization - override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key) + override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(hashTable)) @@ -81,9 +85,9 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { - def this() = this(null) // Needed for serialization + private def this() = this(null) // Needed for serialization - override def get(key: InternalRow): CompactBuffer[InternalRow] = { + override def get(key: InternalRow): Seq[InternalRow] = { val v = hashTable.get(key) if (v eq null) null else CompactBuffer(v) } @@ -109,6 +113,10 @@ private[joins] object HashedRelation { keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { + if (keyGenerator.isInstanceOf[UnsafeProjection]) { + return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + } + // TODO: Use Spark's HashMap implementation. val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate) var currentRow: InternalRow = null @@ -149,31 +157,140 @@ private[joins] object HashedRelation { } } - /** - * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a - * sequence of values. + * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key + * into a sequence of values. + * + * When it's created, it uses HashMap. After it's serialized and deserialized, it switches to using + * BytesToBytesMap for better memory performance (multiple values for the same key are stored as a + * contiguous byte array). + * - * TODO(davies): use BytesToBytesMap + * It's serialized in the following format: + * [number of keys] + * [size of key] [size of all values in bytes] [key bytes] [bytes for all values] + * ... + * + * All the values are serialized as follows: + * [number of fields] [number of bytes] [underlying bytes of UnsafeRow] + * ...
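The per-key value layout described in the comment above (and produced by the writeExternal/readExternal pair that follows) is easiest to see in isolation. Below is a standalone sketch, not part of the patch and using made-up names such as FakeRow, that packs and unpacks the value region as [number of fields][number of bytes][row bytes]... in native byte order, which is also why writeExternal byte-swaps the two length fields on little-endian platforms before handing them to the big-endian DataOutput.

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Illustrative stand-in for a serialized UnsafeRow: its field count plus raw bytes.
final case class FakeRow(numFields: Int, bytes: Array[Byte])

object ValueLayoutSketch {
  // Pack all rows of one key as [number of fields][number of bytes][row bytes]...,
  // in native byte order, mirroring the layout the comment above describes.
  def pack(rows: Seq[FakeRow]): Array[Byte] = {
    val total = rows.map(r => 4 + 4 + r.bytes.length).sum
    val buf = ByteBuffer.allocate(total).order(ByteOrder.nativeOrder())
    rows.foreach { r =>
      buf.putInt(r.numFields)
      buf.putInt(r.bytes.length)
      buf.put(r.bytes)
    }
    buf.array()
  }

  // Walk the packed region back into individual rows, the same way the lookup
  // loop walks the value bytes returned for a key.
  def unpack(packed: Array[Byte]): Seq[FakeRow] = {
    val buf = ByteBuffer.wrap(packed).order(ByteOrder.nativeOrder())
    val out = Seq.newBuilder[FakeRow]
    while (buf.hasRemaining) {
      val numFields = buf.getInt()
      val size = buf.getInt()
      val bytes = new Array[Byte](size)
      buf.get(bytes)
      out += FakeRow(numFields, bytes)
    }
    out.result()
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq(FakeRow(2, Array[Byte](1, 2, 3)), FakeRow(2, Array[Byte](4, 5)))
    val back = unpack(pack(rows))
    assert(back.map(_.numFields) == rows.map(_.numFields))
    assert(back.map(_.bytes.toSeq) == rows.map(_.bytes.toSeq))
  }
}
```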
*/ private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation with Externalizable { - def this() = this(null) // Needed for serialization + private[joins] def this() = this(null) // Needed for serialization + + // Use BytesToBytesMap in executor for better performance (it's created when deserialization) + @transient private[this] var binaryMap: BytesToBytesMap = _ - override def get(key: InternalRow): CompactBuffer[InternalRow] = { + override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] - // Thanks to type eraser - hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] + + if (binaryMap != null) { + // Used in Broadcast join + val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes) + if (loc.isDefined) { + val buffer = CompactBuffer[UnsafeRow]() + + val base = loc.getValueAddress.getBaseObject + var offset = loc.getValueAddress.getBaseOffset + val last = loc.getValueAddress.getBaseOffset + loc.getValueLength + while (offset < last) { + val numFields = PlatformDependent.UNSAFE.getInt(base, offset) + val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4) + offset += 8 + + val row = new UnsafeRow + row.pointTo(base, offset, numFields, sizeInBytes) + buffer += row + offset += sizeInBytes + } + buffer + } else { + null + } + + } else { + // Use the JavaHashMap in Local mode or ShuffleHashJoin + hashTable.get(unsafeKey) + } } override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + out.writeInt(hashTable.size()) + + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + val key = entry.getKey + val values = entry.getValue + + // write all the values as single byte array + var totalSize = 0L + var i = 0 + while (i < values.length) { + totalSize += values(i).getSizeInBytes + 4 + 4 + i += 1 + } + assert(totalSize < Integer.MAX_VALUE, "values are too big") + + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(key.getSizeInBytes) + out.writeInt(totalSize.toInt) + out.write(key.getBytes) + i = 0 + while (i < values.length) { + // [num of fields] [num of bytes] [row bytes] + // write the integer in native order, so they can be read by UNSAFE.getInt() + if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { + out.writeInt(values(i).numFields()) + out.writeInt(values(i).getSizeInBytes) + } else { + out.writeInt(Integer.reverseBytes(values(i).numFields())) + out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + } + out.write(values(i).getBytes) + i += 1 + } + } } override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + val nKeys = in.readInt() + // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory + val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + + val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + .getSizeAsBytes("spark.buffer.pageSize", "64m") + + binaryMap = new BytesToBytesMap( + memoryManager, + nKeys * 2, // reduce hash collision + pageSizeBytes) + + var i = 0 + var keyBuffer = new Array[Byte](1024) + var valuesBuffer = new Array[Byte](1024) + while (i < nKeys) { + val keySize = in.readInt() + val valuesSize = in.readInt() + if (keySize > keyBuffer.size) { + keyBuffer = new Array[Byte](keySize) + } + 
in.readFully(keyBuffer, 0, keySize) + if (valuesSize > valuesBuffer.size) { + valuesBuffer = new Array[Byte](valuesSize) + } + in.readFully(valuesBuffer, 0, valuesSize) + + // put it into binary map + val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) + assert(!loc.isDefined, "Duplicated key found!") + loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, + valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + i += 1 + } } } @@ -181,33 +298,14 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - buildKeys: Seq[Expression], - buildPlan: SparkPlan, - sizeEstimate: Int = 64): HashedRelation = { - val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output)) - apply(input, boundedKeys, buildPlan.schema, sizeEstimate) - } - - // Used for tests - def apply( - input: Iterator[InternalRow], - buildKeys: Seq[Expression], - rowSchema: StructType, + keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { - // TODO: Use BytesToBytesMap. val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) - val toUnsafe = UnsafeProjection.create(rowSchema) - val keyGenerator = UnsafeProjection.create(buildKeys) // Create a mapping of buildKeys -> rows while (input.hasNext) { - val currentRow = input.next() - val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { - currentRow.asInstanceOf[UnsafeRow] - } else { - toUnsafe(currentRow) - } + val unsafeRow = input.next().asInstanceOf[UnsafeRow] val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 874712a4e739f..26a664104d6fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -46,7 +46,7 @@ case class LeftSemiJoinHash( val hashSet = buildKeyHashSet(buildIter) hashSemiJoin(streamIter, hashSet) } else { - val hashRelation = buildHashRelation(buildIter) + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) hashSemiJoin(streamIter, hashRelation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 948d0ccebceb0..5439e10a60b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,7 +45,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = buildHashRelation(buildIter) + val hashed = HashedRelation(buildIter, buildSideKeyGenerator) hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index f54f1edd38ec8..d29b593207c4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,8 +50,8 @@ case class 
ShuffledHashOuterJoin( // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => - val hashed = buildHashRelation(rightIter) - val keyGenerator = streamedKeyGenerator() + val hashed = HashedRelation(rightIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) @@ -59,8 +59,8 @@ case class ShuffledHashOuterJoin( }) case RightOuter => - val hashed = buildHashRelation(leftIter) - val keyGenerator = streamedKeyGenerator() + val hashed = HashedRelation(leftIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index ec084a299649e..ef1c6e57dc08a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -134,8 +134,19 @@ object EvaluatePython { } new GenericInternalRowWithSchema(values, struct) - case (seq: Seq[Any], array: ArrayType) => - seq.map(x => toJava(x, array.elementType)).asJava + case (a: ArrayData, array: ArrayType) => + val length = a.numElements() + val values = new java.util.ArrayList[Any](length) + var i = 0 + while (i < length) { + if (a.isNullAt(i)) { + values.add(null) + } else { + values.add(toJava(a.get(i), array.elementType)) + } + i += 1 + } + values case (obj: Map[_, _], mt: MapType) => obj.map { case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType)) @@ -190,10 +201,10 @@ object EvaluatePython { case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c case (c: java.util.List[_], ArrayType(elementType, _)) => - c.map { e => fromJava(e, elementType)}.toSeq + new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray) case (c, ArrayType(elementType, _)) if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { case (key, value) => (fromJava(key, keyType), fromJava(value, valueType)) @@ -267,7 +278,6 @@ object EvaluatePython { pickler.save(row.values(i)) i += 1 } - row.values.foreach(pickler.save) out.write(Opcodes.TUPLE) out.write(Opcodes.REDUCE) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala new file mode 100644 index 0000000000000..6d903ab23c57f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines various sort operators. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +/** + * Performs a sort on-heap. + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + */ +case class Sort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + iterator.map(_.copy()).toArray.sorted(ordering).iterator + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} + +/** + * Performs a sort, spilling to disk as needed. + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + */ +case class ExternalSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) + sorter.insertAll(iterator.map(r => (r.copy(), null))) + val baseIterator = sorter.iterator.map(_._1) + // TODO(marmbrus): The complex type signature below thwarts inference for no reason. + CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} + +/** + * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Project Tungsten). + * + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will + * spill every `frequency` records. 
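TungstenSort, whose definition follows, sorts binary rows by first comparing a cheap 8-byte prefix and only consulting the full row ordering when two prefixes tie. The sketch below illustrates that idea with plain Scala; the names are illustrative and no Spark APIs are used.

```scala
// Minimal sketch of prefix-based sorting: compare an 8-byte prefix first,
// fall back to the full comparator only on ties.
final case class Record(key: String, payload: Int)

object PrefixSortSketch {
  // Pack the first 8 bytes of the key into a Long so that most comparisons
  // are a single primitive compare rather than a byte-by-byte walk.
  def prefixOf(r: Record): Long = {
    val bytes = r.key.getBytes("UTF-8")
    var p = 0L
    var i = 0
    while (i < 8) {
      val b = if (i < bytes.length) bytes(i) & 0xFFL else 0L
      p = (p << 8) | b
      i += 1
    }
    p
  }

  def sort(records: Seq[Record]): Seq[Record] = {
    val fullOrdering = Ordering.by[Record, String](_.key)
    records.sortWith { (a, b) =>
      val pa = prefixOf(a)
      val pb = prefixOf(b)
      if (pa != pb) java.lang.Long.compareUnsigned(pa, pb) < 0
      else fullOrdering.compare(a, b) < 0
    }
  }

  def main(args: Array[String]): Unit = {
    val data = Seq(Record("banana", 1), Record("apple", 2), Record("apricot", 3))
    assert(sort(data).map(_.key) == Seq("apple", "apricot", "banana"))
  }
}
```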
+ */ +case class TungstenSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val schema = child.schema + val childOutput = child.output + child.execute().mapPartitions({ iter => + val ordering = newOrdering(sortOrder, childOutput) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + + val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + }, preservesPartitioning = true) + } + +} + +object TungstenSort { + /** + * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. + */ + def supportsSchema(schema: StructType): Boolean = { + UnsafeExternalRowSorter.supportsSchema(schema) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 78da2840dad69..9329148aa233c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{DataType, ArrayType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame} private[sql] object FrequentItems extends Logging { @@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging { baseCounts } ) - val justItems = freqItems.map(m => m.baseMap.keys.toSeq) + val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_)) val resultRow = InternalRow(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cab3db609dd4b..46dc4605a5ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -634,7 +634,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + def monotonicallyIncreasingId(): Column = 
MonotonicallyIncreasingID() /** * Return an alternative value `r` if `l` is NaN. @@ -741,7 +741,16 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def sparkPartitionId(): Column = execution.expressions.SparkPartitionID + def sparkPartitionId(): Column = SparkPartitionID() + + /** + * The file name of the current Spark task + * + * Note that this is indeterministic becuase it depends on what is currently being read in. + * + * @group normal_funcs + */ + def inputFileName(): Column = InputFileName() /** * Computes the square root of the specified float value. @@ -1423,7 +1432,8 @@ object functions { def round(columnName: String): Column = round(Column(columnName), 0) /** - * Returns the value of `e` rounded to `scale` decimal places. + * Round the value of `e` to `scale` decimal places if `scale` >= 0 + * or at integral part when `scale` < 0. * * @group math_funcs * @since 1.5.0 @@ -1431,7 +1441,8 @@ object functions { def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) /** - * Returns the value of the given column rounded to `scale` decimal places. + * Round the value of the given column to `scale` decimal places if `scale` >= 0 + * or at integral part when `scale` < 0. * * @group math_funcs * @since 1.5.0 @@ -1916,6 +1927,14 @@ object functions { // DateTime functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns the date that is numMonths after startDate. + * @group datetime_funcs + * @since 1.5.0 + */ + def add_months(startDate: Column, numMonths: Int): Column = + AddMonths(startDate.expr, Literal(numMonths)) + /** * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. @@ -1948,6 +1967,20 @@ object functions { def date_format(dateColumnName: String, format: String): Column = date_format(Column(dateColumnName), format) + /** + * Returns the date that is `days` days after `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + + /** + * Returns the date that is `days` days before `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + /** * Extracts the year as an integer from a given date/timestamp/string. * @group datetime_funcs @@ -2032,6 +2065,16 @@ object functions { */ def hour(columnName: String): Column = hour(Column(columnName)) + /** + * Given a date column, returns the last day of the month which the given date belongs to. + * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the + * month in July 2015. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def last_day(e: Column): Column = LastDay(e.expr) + /** * Extracts the minutes as an integer from a given date/timestamp/string. * @group datetime_funcs @@ -2046,6 +2089,28 @@ object functions { */ def minute(columnName: String): Column = minute(Column(columnName)) + /* + * Returns number of months between dates `date1` and `date2`. + * @group datetime_funcs + * @since 1.5.0 + */ + def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + + /** + * Given a date column, returns the first date which is later than the value of the date column + * that is on the specified day of the week. 
+ * + * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first + * Sunday after 2015-07-27. + * + * Day of the week parameter is case insensitive, and accepts: + * "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". + * + * @group datetime_funcs + * @since 1.5.0 + */ + def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr) + /** * Extracts the seconds as an integer from a given date/timestamp/string. * @group datetime_funcs @@ -2074,6 +2139,64 @@ object functions { */ def weekofyear(columnName: String): Column = weekofyear(Column(columnName)) + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f)) + + /** + * Gets current Unix timestamp in seconds. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), + * using the default timezone and the default locale, return null if fail. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Convert time string with given pattern + * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) + * to Unix time stamp (in seconds), return null if fail. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + + /* + * Converts the column into DateType. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def to_date(e: Column): Column = ToDate(e.expr) + + /** + * Returns date truncated to the unit specified by the format. 
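The datetime helpers added in this hunk (date_add, date_sub, add_months, last_day, next_day, months_between, from_unixtime, unix_timestamp, to_date, and the trunc function defined just below) are easiest to read from the caller's side. This usage sketch is not part of the patch; it assumes a running sqlContext with its implicits imported, as the test suites later in this patch do, and the column names are illustrative.

```scala
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.functions._
import sqlContext.implicits._

val df = Seq(
  (Date.valueOf("2015-07-27"), Timestamp.valueOf("2015-07-27 12:34:56"))
).toDF("d", "t")

df.select(
  date_add($"d", 7),           // 2015-08-03
  date_sub($"d", 7),           // 2015-07-20
  add_months($"d", 1),         // 2015-08-27
  last_day($"d"),              // 2015-07-31, last day of that month
  next_day($"d", "Sunday"),    // 2015-08-02, first Sunday after the date
  months_between($"t", $"d"),  // fractional number of months between the two
  trunc($"d", "mm"),           // 2015-07-01, truncated to the month
  to_date($"t"),               // 2015-07-27
  unix_timestamp($"t"),        // seconds since the epoch
  from_unixtime(unix_timestamp($"t"))
).show()
```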
+ * + * @group datetime_funcs + * @since 1.5.0 + */ + def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 0eb3b04007f8d..04ab5e2217882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -125,7 +125,7 @@ private[sql] object InferSchema { * Convert NullType to StringType and remove StructTypes with no fields */ private def canonicalizeType: DataType => Option[DataType] = { - case at@ArrayType(elementType, _) => + case at @ ArrayType(elementType, _) => for { canonicalType <- canonicalizeType(elementType) } yield { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 381e7ed54428f..1c309f8794ef3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -110,8 +110,13 @@ private[sql] object JacksonParser { case (START_OBJECT, st: StructType) => convertObject(factory, parser, st) + case (START_ARRAY, st: StructType) => + // SPARK-3308: support reading top level JSON arrays and take every element + // in such an array as a row + convertArray(factory, parser, st) + case (START_ARRAY, ArrayType(st, _)) => - convertList(factory, parser, st) + convertArray(factory, parser, st) case (START_OBJECT, ArrayType(st, _)) => // the business end of SPARK-3308: @@ -165,16 +170,16 @@ private[sql] object JacksonParser { builder.result() } - private def convertList( + private def convertArray( factory: JsonFactory, parser: JsonParser, - schema: DataType): Seq[Any] = { - val builder = Seq.newBuilder[Any] + elementType: DataType): ArrayData = { + val values = scala.collection.mutable.ArrayBuffer.empty[Any] while (nextUntil(parser, JsonToken.END_ARRAY)) { - builder += convertField(factory, parser, schema) + values += convertField(factory, parser, elementType) } - builder.result() + new GenericArrayData(values.toArray) } private def parseJson( @@ -201,12 +206,15 @@ private[sql] object JacksonParser { val parser = factory.createParser(record) parser.nextToken() - // to support both object and arrays (see SPARK-3308) we'll start - // by converting the StructType schema to an ArrayType and let - // convertField wrap an object into a single value array when necessary. - convertField(factory, parser, ArrayType(schema)) match { + convertField(factory, parser, schema) match { case null => failedRecord(record) - case list: Seq[InternalRow @unchecked] => list + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray().map(_.asInstanceOf[InternalRow]) + } case _ => sys.error( s"Failed to parse record $record. 
Please make sure that each line of the file " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index e00bd90edb3dd..172db8362afb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -325,7 +325,7 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = elementConverter - override def end(): Unit = updater.set(currentArray) + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) // NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the // next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index ea51650fe9039..2332a36468dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.parquet import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.ArrayData // TODO Removes this while fixing SPARK-8848 private[sql] object CatalystConverter { @@ -32,7 +33,7 @@ private[sql] object CatalystConverter { val MAP_SCHEMA_NAME = "map" // TODO: consider using Array[T] for arrays to avoid boxing of primitive types - type ArrayScalaType[T] = Seq[T] + type ArrayScalaType[T] = ArrayData type StructScalaType[T] = InternalRow type MapScalaType[K, V] = Map[K, V] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index cc6fa2b88663f..b4337a48dbd80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -39,11 +39,10 @@ import org.apache.parquet.{Log => ParquetLog} import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD} import org.apache.spark.rdd.RDD._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} @@ -125,6 +124,9 @@ private[sql] class ParquetRelation( .map(_.toBoolean) .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) + private val mergeRespectSummaries = + sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) + private val maybeMetastoreSchema = parameters .get(ParquetRelation.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) @@ -422,7 +424,21 @@ private[sql] class ParquetRelation( val filesToTouch = if (shouldMergeSchemas) { // Also includes summary files, 'cause there might be empty partition directories. 
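The mergeRespectSummaries flag read a few lines above changes which Parquet footers the hunk below feeds into schema merging: with the flag on, only the _metadata and _common_metadata summary files are merged; with it off (the default), every part-file footer is read as before. A usage sketch, not part of the patch; the string conf key is an assumption here, the authoritative name is whatever SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES binds to.

```scala
// Assumed conf key for SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.
sqlContext.setConf("spark.sql.parquet.respectSummaryFiles", "true")

val merged = sqlContext.read
  .option("mergeSchema", "true")          // turn schema merging on for this read
  .parquet("/path/to/partitioned/table")  // illustrative path
merged.printSchema()
```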
- (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq + + // If mergeRespectSummaries config is true, we assume that all part-files are the same for + // their schema with summary files, so we ignore them when merging schema. + // If the config is disabled, which is the default setting, we merge all part-files. + // In this mode, we only need to merge schemas contained in all those summary files. + // You should enable this configuration only if you are very sure that for the parquet + // part-files to read there are corresponding summary files containing correct schema. + + val needMerged: Seq[FileStatus] = + if (mergeRespectSummaries) { + Seq() + } else { + dataStatuses + } + (metadataStatuses ++ commonMetadataStatuses ++ needMerged).toSeq } else { // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet // don't have this. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 78ecfad1d57c6..ec8da38a3d427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -146,15 +146,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo array: CatalystConverter.ArrayScalaType[_]): Unit = { val elementType = schema.elementType writer.startGroup() - if (array.size > 0) { + if (array.numElements() > 0) { if (schema.containsNull) { writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) var i = 0 - while (i < array.size) { + while (i < array.numElements()) { writer.startGroup() - if (array(i) != null) { + if (!array.isNullAt(i)) { writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - writeValue(elementType, array(i)) + writeValue(elementType, array.get(i)) writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) } writer.endGroup() @@ -164,8 +164,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo } else { writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) var i = 0 - while (i < array.size) { - writeValue(elementType, array(i)) + while (i < array.numElements()) { + writeValue(elementType, array.get(i)) i = i + 1 } writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) @@ -293,8 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) case BinaryType => writer.addBinary(Binary.fromByteArray(record.getBinary(index))) - case DecimalType.Fixed(precision, _) => - writeDecimal(record.getDecimal(index), precision) + case DecimalType.Fixed(precision, scale) => + writeDecimal(record.getDecimal(index, precision, scale), precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 72c42f4fe376b..2c669bb59a0b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -30,7 +30,6 @@ import scala.collection.JavaConversions; import scala.collection.Seq; -import scala.collection.mutable.Buffer; import java.io.Serializable; import java.util.Arrays; @@ -168,10 +167,10 @@ public void 
testCreateDataFrameFromJavaBeans() { for (int i = 0; i < result.length(); i++) { Assert.assertEquals(bean.getB()[i], result.apply(i)); } - Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); + Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer))); + Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer))); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { @@ -227,4 +226,13 @@ public void testCovariance() { Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1e-6); } + + @Test + public void testSampleBy() { + DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Assert.assertArrayEquals(expected, actual); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 1f9f7118c3f04..eb64684ae0fd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,16 +19,19 @@ package org.apache.spark.sql import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest { +class ColumnExpressionSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.TestData._ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") @@ -489,6 +492,18 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("InputFileName") { + withTempPath { dir => + val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + data.write.parquet(dir.getCanonicalPath) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + .head.getString(0) + assert(answer.contains(dir.getCanonicalPath)) + + checkAnswer(data.select(inputFileName()).limit(1), Row("")) + } + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, @@ -523,6 +538,7 @@ class ColumnExpressionSuite extends QueryTest { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { case project: Project => project + case tungstenProject: TungstenProject => tungstenProject } assert(projects.size === expectedNumProjects) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b26d3ab253a1d..228ece8065151 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ 
import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{BinaryType, DecimalType} class DataFrameAggregateSuite extends QueryTest { @@ -191,4 +191,13 @@ class DataFrameAggregateSuite extends QueryTest { Row(null)) } + test("aggregation can't work on binary type") { + val df = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType) + intercept[AnalysisException] { + df.groupBy("c").agg(count("*")) + } + intercept[AnalysisException] { + df.distinct + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 7ba4ba73e0cc9..07a675e64f527 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -21,9 +21,9 @@ import java.util.Random import org.scalatest.Matchers._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions.col -class DataFrameStatSuite extends SparkFunSuite { +class DataFrameStatSuite extends QueryTest { private val sqlCtx = org.apache.spark.sql.test.TestSQLContext import sqlCtx.implicits._ @@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite { val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) } + + test("sampleBy") { + val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 5), Row(1, 8))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f67f2c60c0e16..97beae2f85c50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -23,40 +23,38 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ +import org.apache.spark.sql.json.JSONRelation +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} class DataFrameSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.TestData._ - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - def sqlContext: SQLContext = ctx + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ test("analysis error should be eagerly reported") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. 
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - - intercept[Exception] { testData.select('nonExistentName) } - intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) - } - intercept[Exception] { - testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) - } - intercept[Exception] { - testData.groupBy($"abcd").agg(Map("key" -> "sum")) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { + testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + } + intercept[Exception] { + testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + } + intercept[Exception] { + testData.groupBy($"abcd").agg(Map("key" -> "sum")) + } } // No more eager analysis once the flag is turned off - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) - testData.select('nonExistentName) - - // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") { + testData.select('nonExistentName) + } } test("dataframe toString") { @@ -74,21 +72,18 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("invalid plan toString, debug mode") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - ctx.debug() - val badPlan = testData.select('badColumn) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + sqlContext.debug() - assert(badPlan.toString contains badPlan.queryExecution.toString, - "toString on bad query plans should include the query execution but was:\n" + - badPlan.toString) + val badPlan = testData.select('badColumn) - // Set the flag back to original value before this test. 
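The hunks in this suite replace the manual save/set/restore of SQLConf flags (the removed lines around here) with the withSQLConf helper from SQLTestUtils. The sketch below shows the shape of that pattern; it is not the actual SQLTestUtils implementation, just the idea of setting keys for the duration of a block and restoring them afterwards.

```scala
import scala.util.Try
import org.apache.spark.sql.SQLContext

// Set the given conf pairs for the duration of `body`, then restore what was there.
def withConf[T](ctx: SQLContext)(pairs: (String, String)*)(body: => T): T = {
  // Remember the previous values (None if the key had no readable value).
  val originals = pairs.map { case (k, _) => k -> Try(ctx.getConf(k)).toOption }
  pairs.foreach { case (k, v) => ctx.setConf(k, v) }
  try body
  finally originals.foreach {
    case (k, Some(old)) => ctx.setConf(k, old)
    case (_, None)      => () // nothing recorded; the real helper handles this case too
  }
}

// e.g. withConf(ctx)("spark.sql.codegen" -> "true") { /* run assertions */ }
```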
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + assert(badPlan.toString contains badPlan.queryExecution.toString, + "toString on bad query plans should include the query execution but was:\n" + + badPlan.toString) + } } test("access complex data") { @@ -104,8 +99,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("empty data frame") { - assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(ctx.emptyDataFrame.count() === 0) + assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(sqlContext.emptyDataFrame.count() === 0) } test("head and take") { @@ -341,7 +336,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("replace column using withColumn") { - val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -422,7 +417,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("randomSplit") { val n = 600 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -491,6 +486,23 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) } + test("inputFiles") { + val fakeRelation1 = new ParquetRelation(Array("/my/path", "/my/other/path"), + Some(testData.schema), None, Map.empty)(sqlContext) + val df1 = DataFrame(sqlContext, LogicalRelation(fakeRelation1)) + assert(df1.inputFiles.toSet == fakeRelation1.paths.toSet) + + val fakeRelation2 = new JSONRelation("/json/path", 1, Some(testData.schema), sqlContext) + val df2 = DataFrame(sqlContext, LogicalRelation(fakeRelation2)) + assert(df2.inputFiles.toSet == fakeRelation2.path.toSet) + + val unionDF = df1.unionAll(df2) + assert(unionDF.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path) + + val filtered = df1.filter("false").unionAll(df2.intersect(df2)) + assert(filtered.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path) + } + ignore("show") { // This test case is intended ignored, but to make sure it compiles correctly testData.select($"*").show() @@ -499,7 +511,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("showString: truncate = [true, false]") { val longString = Array.fill(21)("1").mkString - val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF() + val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ ||_1 | |+---------------------+ @@ -589,21 +601,17 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = ctx.createDataFrame(rowRDD, schema) + val df = sqlContext.createDataFrame(rowRDD, schema) df.rdd.collect() } - test("SPARK-6899") { - val originalValue = ctx.conf.codegenEnabled - ctx.setConf(SQLConf.CODEGEN_ENABLED, true) - try{ + test("SPARK-6899: type should match when using codegen") { + 
withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) - } finally { - ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -615,14 +623,14 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = ctx.read.json(ctx.sparkContext.makeRDD( + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = ctx.read.json(ctx.sparkContext.makeRDD( + val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -642,7 +650,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7324 dropDuplicates") { - val testData = ctx.sparkContext.parallelize( + val testData = sqlContext.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -690,49 +698,49 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = ctx.range(0, 10, 1, 15).select("id") + val res1 = sqlContext.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = ctx.range(3, 15, 3, 2).select("id") + val res2 = sqlContext.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = ctx.range(1, -2).select("id") + val res3 = sqlContext.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = ctx.range(1, -2, -2, 6).select("id") + val res4 = sqlContext.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = ctx.range(-3, -8, -2, 1).select("id") + val res5 = sqlContext.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = ctx.range(-8, -4, 2, 1).select("id") + val res6 = sqlContext.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = ctx.range(-10, -9, -20, 1).select("id") + val res7 = sqlContext.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = ctx.range(10).select("id") + val res10 = sqlContext.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = ctx.range(-1).select("id") + val res11 = sqlContext.range(-1).select("id") assert(res11.count == 0) } @@ -799,13 +807,13 @@ class DataFrameSuite 
extends QueryTest with SQLTestUtils { // pass case: parquet table (HadoopFsRelation) df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = ctx.read.parquet(tempParquetFile.getCanonicalPath) + val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) pdf.registerTempTable("parquet_base") insertion.write.insertInto("parquet_base") // pass case: json table (InsertableRelation) df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = ctx.read.json(tempJsonFile.getCanonicalPath) + val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) jdf.registerTempTable("json_base") insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") @@ -825,11 +833,54 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - new DataFrame(ctx, OneRowRelation).registerTempTable("one_row") + new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) } } + + test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + test("SPARK-8609: local DataFrame with random columns should return same value after sort") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF() + checkAnswer(df.sort(rand(33)), df.sort(rand(33))) + } + + test("SPARK-9083: sort with non-deterministic expressions") { + import org.apache.spark.util.random.XORShiftRandom + + val seed = 33 + val df = (1 to 100).map(Tuple1.apply).toDF("i") + val random = new XORShiftRandom(seed) + val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1) + val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) + assert(expected === actual) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala new file mode 100644 index 0000000000000..bf8ef9a97bc60 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + +/** + * An end-to-end test suite specifically for testing Tungsten (Unsafe/CodeGen) mode. + * + * This is here for now so I can make sure Tungsten project is tested without refactoring existing + * end-to-end test infra. In the long run this should just go away. + */ +class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { + + override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ + + test("test simple types") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) + } + } + + test("test struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) + } + } + + test("test nested struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 9e80ae86920d9..8c596fad74ee4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -20,13 +20,36 @@ package org.apache.spark.sql import java.sql.{Timestamp, Date} import java.text.SimpleDateFormat +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.unsafe.types.CalendarInterval class DateFunctionsSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + test("function current_date") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val d1 = 
DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) + val d2 = DateTimeUtils.fromJavaDate( + ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) + } + + // This is a bad test. SPARK-9196 will fix it and re-enable it. + ignore("function current_timestamp") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value + checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + Row(true)) + assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + 0).getTime - System.currentTimeMillis()) < 5000) + } + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val sdfDate = new SimpleDateFormat("yyyy-MM-dd") val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) @@ -184,4 +207,242 @@ class DateFunctionsSuite extends QueryTest { Row(15, 15, 15)) } + test("function date_add") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_add(col("d"), 1)), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + checkAnswer( + df.select(date_add(col("t"), 3)), + Seq(Row(Date.valueOf("2015-06-04")), Row(Date.valueOf("2015-06-05")))) + checkAnswer( + df.select(date_add(col("s"), 5)), + Seq(Row(Date.valueOf("2015-06-06")), Row(Date.valueOf("2015-06-07")))) + checkAnswer( + df.select(date_add(col("ss"), 7)), + Seq(Row(Date.valueOf("2015-06-08")), Row(Date.valueOf("2015-06-09")))) + + checkAnswer(df.selectExpr("DATE_ADD(null, 1)"), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_ADD(d, 1)"""), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + } + + test("function date_sub") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_sub(col("d"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("t"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("s"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("ss"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(lit(null), 1)).limit(1), Row(null)) + + checkAnswer(df.selectExpr("""DATE_SUB(d, null)"""), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_SUB(d, 1)"""), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + } + + test("time_add") { + val t1 = Timestamp.valueOf("2015-07-31 23:59:59") + val t2 = Timestamp.valueOf("2015-12-31 00:00:00") + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-12-31") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, 
t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d + $i"), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2016-02-29")))) + checkAnswer( + df.selectExpr(s"t + $i"), + Seq(Row(Timestamp.valueOf("2015-10-01 00:00:01")), + Row(Timestamp.valueOf("2016-02-29 00:00:02")))) + } + + test("time_sub") { + val t1 = Timestamp.valueOf("2015-10-01 00:00:01") + val t2 = Timestamp.valueOf("2016-02-29 00:00:02") + val d1 = Date.valueOf("2015-09-30") + val d2 = Date.valueOf("2016-02-29") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d - $i"), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-12-30")))) + checkAnswer( + df.selectExpr(s"t - $i"), + Seq(Row(Timestamp.valueOf("2015-07-31 23:59:59")), + Row(Timestamp.valueOf("2015-12-31 00:00:00")))) + } + + test("function add_months") { + val d1 = Date.valueOf("2015-08-31") + val d2 = Date.valueOf("2015-02-28") + val df = Seq((1, d1), (2, d2)).toDF("n", "d") + checkAnswer( + df.select(add_months(col("d"), 1)), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-31")))) + checkAnswer( + df.selectExpr("add_months(d, -1)"), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-31")))) + } + + test("function months_between") { + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-02-16") + val t1 = Timestamp.valueOf("2014-09-30 23:30:00") + val t2 = Timestamp.valueOf("2015-09-16 12:00:00") + val s1 = "2014-09-15 11:30:00" + val s2 = "2015-10-01 00:00:00" + val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") + checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + } + + test("function last_day") { + val df1 = Seq((1, "2015-07-23"), (2, "2015-07-24")).toDF("i", "d") + val df2 = Seq((1, "2015-07-23 00:11:22"), (2, "2015-07-24 11:22:33")).toDF("i", "t") + checkAnswer( + df1.select(last_day(col("d"))), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) + checkAnswer( + df2.select(last_day(col("t"))), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) + } + + test("function next_day") { + val df1 = Seq(("mon", "2015-07-23"), ("tuesday", "2015-07-20")).toDF("dow", "d") + val df2 = Seq(("th", "2015-07-23 00:11:22"), ("xx", "2015-07-24 11:22:33")).toDF("dow", "t") + checkAnswer( + df1.select(next_day(col("d"), "MONDAY")), + Seq(Row(Date.valueOf("2015-07-27")), Row(Date.valueOf("2015-07-27")))) + checkAnswer( + df2.select(next_day(col("t"), "th")), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) + } + + test("function to_date") { + val d1 = Date.valueOf("2015-07-22") + val d2 = Date.valueOf("2015-07-01") + val t1 = Timestamp.valueOf("2015-07-22 10:00:00") + val t2 = Timestamp.valueOf("2014-12-31 23:59:59") + val s1 = "2015-07-22 10:00:00" + val s2 = "2014-12-31" + val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s") + + checkAnswer( + df.select(to_date(col("t"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.select(to_date(col("d"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.select(to_date(col("s"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + + checkAnswer( + df.selectExpr("to_date(t)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + 
df.selectExpr("to_date(d)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.selectExpr("to_date(s)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + } + + test("function trunc") { + val df = Seq( + (1, Timestamp.valueOf("2015-07-22 10:00:00")), + (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select(trunc(col("t"), "YY")), + Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01")))) + + checkAnswer( + df.selectExpr("trunc(t, 'Month')"), + Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) + } + + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd HH-mm-ss" + val sdf3 = new SimpleDateFormat(fmt3) + val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") + checkAnswer( + df.select(from_unixtime(col("a"))), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt2)), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt3)), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr("from_unixtime(a)"), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt2')"), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt3')"), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + } + + test("unix_timestamp") { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala deleted file mode 100644 index 44b915304533c..0000000000000 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions._ - -class DatetimeExpressionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - - import ctx.implicits._ - - lazy val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") - - test("function current_date") { - val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) - val d2 = DateTimeUtils.fromJavaDate( - ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) - val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) - } - - test("function current_timestamp") { - checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) - // Execution in one query should return the same value - checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), - Row(true)) - assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( - 0).getTime - System.currentTimeMillis()) < 5000) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index dfb2a7e099748..27c08f64649ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.types.BinaryType class JoinSuite extends QueryTest with BeforeAndAfterEach { @@ -79,9 +80,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN 
testData2 ON key = a where key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashOuterJoin]), @@ -489,4 +490,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(3, 2) :: Nil) } + + test("Join can't work on binary type") { + val left = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType) + val right = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("d").select($"d" cast BinaryType) + intercept[AnalysisException] { + left.join(right, ($"left.N" === $"right.N"), "full") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 21256704a5b16..8cf2ef5957d8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -216,7 +216,8 @@ class MathExpressionsSuite extends QueryTest { checkAnswer( ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), - Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), + BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 358e319476e83..535011fe3db5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -227,6 +227,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Seq(Row("1"), Row("2"))) } + test("SPARK-8828 sum should return null if all input values are null") { + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + } + test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -337,7 +368,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(1)) checkAnswer( sql("SELECT COALESCE(null, 1, 1.5)"), - Row(1.toDouble)) + Row(BigDecimal(1))) checkAnswer( sql("SELECT COALESCE(null, null, null)"), Row(null)) @@ -1203,19 +1234,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(0.3) + sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) ) checkAnswer( - sql("SELECT -0.8"), Row(-0.8) + sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) ) checkAnswer( - sql("SELECT .5"), Row(0.5) + sql("SELECT .5"), Row(BigDecimal(0.5)) 
) checkAnswer( - sql("SELECT -.18"), Row(-0.18) + sql("SELECT -.18"), Row(BigDecimal(-0.18)) ) } @@ -1248,11 +1279,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) checkAnswer( - sql("SELECT -5.2"), Row(-5.2) + sql("SELECT -5.2"), Row(BigDecimal(-5.2)) ) checkAnswer( - sql("SELECT +6.8"), Row(6.8) + sql("SELECT +6.8"), Row(BigDecimal(6.8)) ) checkAnswer( @@ -1546,10 +1577,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-8753: add interval type") { - import org.apache.spark.unsafe.types.Interval + import org.apache.spark.unsafe.types.CalendarInterval val df = sql("select interval 3 years -3 month 7 week 123 microseconds") - checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) withTempPath(f => { // Currently we don't yet support saving out values of interval data type. val e = intercept[AnalysisException] { @@ -1571,20 +1602,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-8945: add and subtract expressions for interval type") { - import org.apache.spark.unsafe.types.Interval - import org.apache.spark.unsafe.types.Interval.MICROS_PER_WEEK + import org.apache.spark.unsafe.types.CalendarInterval + import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_WEEK val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") - checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) - checkAnswer(df.select(df("i") + new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123))) + checkAnswer(df.select(df("i") + new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123))) - checkAnswer(df.select(df("i") - new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123))) + checkAnswer(df.select(df("i") - new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123))) // unary minus checkAnswer(df.select(-df("i")), - Row(new Interval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) + Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 0f9c986f649a1..8e0ea76d15881 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -57,19 +57,27 @@ class StringFunctionsSuite extends QueryTest { } test("string regex_replace / regex_extract") { - val df = Seq(("100-200", "")).toDF("a", "b") + val df = Seq( + ("100-200", "(\\d+)-(\\d+)", "300"), + ("100-200", "(\\d+)-(\\d+)", "400"), + ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") checkAnswer( df.select( regexp_replace($"a", "(\\d+)", "num"), regexp_extract($"a", "(\\d+)-(\\d+)", 1)), - Row("num-num", "100")) - - checkAnswer( - df.selectExpr( - "regexp_replace(a, '(\\d+)', 'num')", - "regexp_extract(a, '(\\d+)-(\\d+)', 2)"), - Row("num-num", "200")) + Row("num-num", "100") :: Row("num-num", "100") :: Row("num-num", "100") :: Nil) + + // for testing the mutable state of the expression in code gen. 
+ // This is a hack to make sure codegen is exercised: although codegen is enabled by default, + // the interpreted projection is still used when a projection sits directly on top of a LocalRelation, + // hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + checkAnswer( + df.filter("isnotnull(a)").selectExpr( + "regexp_replace(a, b, c)", + "regexp_extract(a, b, 1)"), + Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) } test("string ascii function") { @@ -290,5 +298,15 @@ class StringFunctionsSuite extends QueryTest { df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable Row("5.0000")) } + + // for testing the mutable state of the expression in code gen. + // This is a hack to make sure codegen is exercised: although codegen is enabled by default, + // the interpreted projection is still used when a projection sits directly on top of a LocalRelation, + // hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + val df2 = Seq((5L, 4), (4L, 3), (3L, 2)).toDF("a", "b") + checkAnswer( + df2.filter("b>0").selectExpr("format_number(a, b)"), + Row("5.0000") :: Row("4.000") :: Row("3.00") :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 207d7a352c7b3..e340f54850bcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.sql.Timestamp - import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index c1516b450cbd4..183dc3407b3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SQLTestUtils case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { +class UDFSuite extends QueryTest with SQLTestUtils { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") @@ -51,6 +54,25 @@ class UDFSuite extends QueryTest { df.selectExpr("count(distinct a)") } + test("SPARK-8003 spark_partition_id") { + val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") + df.registerTempTable("tmp_table") + checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) + ctx.dropTempTable("tmp_table") + } + + test("SPARK-8005 input_file_name") { + withTempPath { dir => + val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + data.write.parquet(dir.getCanonicalPath) + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + assert(answer.contains(dir.getCanonicalPath)) + assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) + ctx.dropTempTable("test_table") + } + } + test("error reporting for incorrect number of arguments") { val df = ctx.emptyDataFrame val e = intercept[AnalysisException] { diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index ad3bb1744cb3c..e72a1bc6c4e20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -67,4 +67,19 @@ class UnsafeRowSuite extends SparkFunSuite { assert(bytesFromArrayBackedRow === bytesFromOffheapRow) } + + test("calling getDouble() and getFloat() on null columns") { + val row = InternalRow.apply(null, null) + val unsafeRow = UnsafeProjection.create(Array[DataType](FloatType, DoubleType)).apply(row) + assert(unsafeRow.getFloat(0) === row.getFloat(0)) + assert(unsafeRow.getDouble(1) === row.getDouble(1)) + } + + test("calling get(ordinal, datatype) on null columns") { + val row = InternalRow.apply(null) + val unsafeRow = UnsafeProjection.create(Array[DataType](NullType)).apply(row) + for (dataType <- DataTypeTestUtils.atomicTypes) { + assert(unsafeRow.get(0, dataType) === null) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 45c9f06941c10..77ed4a9c0d5ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -47,17 +47,17 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(obj: Any): Seq[Double] = { + override def serialize(obj: Any): ArrayData = { obj match { case features: MyDenseVector => - features.data.toSeq + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) } } override def deserialize(datum: Any): MyDenseVector = { datum match { - case data: Seq[_] => - new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + case data: ArrayData => + new MyDenseVector(data.toArray.map(_.asInstanceOf[Double])) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 4499a7207031d..66014ddca0596 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -34,8 +34,7 @@ class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[DoubleColumnStats], DOUBLE, InternalRow(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testColumnStats(classOf[FixedDecimalColumnStats], - FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) + testDecimalColumnStats(InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], @@ -52,7 +51,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import ColumnarTestUtils._ + import 
org.apache.spark.sql.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -73,4 +72,39 @@ class ColumnStatsSuite extends SparkFunSuite { } } } + + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + + val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName + val columnType = FIXED_DECIMAL(15, 10) + + test(s"$columnStatsName: empty") { + val columnStats = new FixedDecimalColumnStats(15, 10) + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"$columnStatsName: non-empty") { + import org.apache.spark.sql.columnar.ColumnarTestUtils._ + + val columnStats = new FixedDecimalColumnStats(15, 10) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] + val stats = columnStats.collectedStatistics + + assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1)) + assertResult(10, "Wrong null count")(stats.genericGet(2)) + assertResult(20, "Wrong row count")(stats.genericGet(3)) + assertResult(stats.genericGet(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 7b75f755918c1..707cd9c6d939b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} import org.apache.spark.sql.test.TestSQLContext class RowFormatConvertersSuite extends SparkPlanTest { @@ -31,7 +30,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { @@ -41,14 +40,14 @@ class RowFormatConvertersSuite extends SparkPlanTest { } test("filter can process unsafe rows") { - val plan = Filter(IsNull(null), outputsUnsafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) + assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { - val plan = Filter(IsNull(null), outputsSafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) val preparedPlan = 
TestSQLContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 6a8f394545816..f46855edfe0de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row} +import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row} import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -33,11 +33,13 @@ import scala.util.control.NonFatal */ class SparkPlanTest extends SparkFunSuite { + protected def sqlContext: SQLContext = TestSQLContext + /** * Creates a DataFrame from a local Seq of Product. */ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - TestSQLContext.implicits.localSeqToDataFrameHolder(data) + sqlContext.implicits.localSeqToDataFrameHolder(data) } /** @@ -98,7 +100,7 @@ class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -121,7 +123,8 @@ class SparkPlanTest extends SparkFunSuite { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match { + SparkPlanTest.checkAnswer( + input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -147,13 +150,14 @@ object SparkPlanTest { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, - sortAnswers: Boolean): Option[String] = { + sortAnswers: Boolean, + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan) + executePlan(expectedOutputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -168,7 +172,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -207,12 +211,13 @@ object SparkPlanTest { input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], - sortAnswers: Boolean): Option[String] = { + sortAnswers: Boolean, + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -275,10 +280,10 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: 
SparkPlan): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = TestSQLContext.prepareForExecution.execute( + val resolvedPlan = sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala similarity index 70% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 7a4baa9e4a49d..450963547c798 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { +class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { override def beforeAll(): Unit = { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -36,39 +36,21 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) } - ignore("sort followed by limit should not leak memory") { - // TODO: this test is going to fail until we implement a proper iterator interface - // with a close() method. 
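// A minimal, Spark-free sketch of the set-conf / run / restore pattern that the surrounding suites
// rely on (beforeAll/afterAll with setConf here, withSQLConf elsewhere in this patch). Everything
// below is illustrative only: WithConfSketch, withConf and the key "example.codegen.enabled" are
// stand-ins, not Spark APIs.
object WithConfSketch {
  private val conf = scala.collection.mutable.Map[String, String]()

  def withConf[T](pairs: (String, String)*)(body: => T): T = {
    // Remember the original values; None means the key was previously unset.
    val originals = pairs.map { case (k, _) => k -> conf.get(k) }
    pairs.foreach { case (k, v) => conf(k) = v }
    try body
    finally originals.foreach {
      case (k, Some(v)) => conf(k) = v    // restore the previous value
      case (k, None)    => conf.remove(k) // or clear the key entirely
    }
  }

  def main(args: Array[String]): Unit = {
    conf("example.codegen.enabled") = "false"
    withConf("example.codegen.enabled" -> "true") {
      assert(conf("example.codegen.enabled") == "true") // overridden inside the block
    }
    assert(conf("example.codegen.enabled") == "false")  // restored afterwards
  }
}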
- TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") + test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) } - test("sort followed by limit") { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") - try { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } finally { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") - - } - } - test("sorting does not crash for large inputs") { val sortOrder = 'a.asc :: Nil val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) @@ -88,11 +70,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) + assert(TungstenSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, plan => ConvertToSafe( - UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 9dd2220f0967e..8b1a9b21a96b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.joins +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.types.{StructField, StructType, IntegerType} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer @@ -64,27 +65,34 @@ class HashedRelationSuite extends SparkFunSuite { } test("UnsafeHashedRelation") { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val toUnsafe = UnsafeProjection.create(schema) + val unsafeData = data.map(toUnsafe(_).copy()).toArray + val buildKey = Seq(BoundReference(0, IntegerType, false)) - val schema = StructType(StructField("a", IntegerType, true) :: Nil) - val hashed = 
UnsafeHashedRelation(data.iterator, buildKey, schema, 1) + val keyGenerator = UnsafeProjection.create(buildKey) + val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) - val toUnsafeKey = UnsafeProjection.create(schema) - val unsafeData = data.map(toUnsafeKey(_).copy()).toArray assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) - assert(hashed.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed.get(toUnsafe(InternalRow(10))) === null) val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) data2 += unsafeData(2).copy() assert(hashed.get(unsafeData(2)) === data2) - val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) - .asInstanceOf[UnsafeHashedRelation] + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) - assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 3ac312d6f4c50..f19f22fca7d54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -422,14 +422,14 @@ class JsonSuite extends QueryTest with TestJsonData { Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) - // Widening to DoubleType + // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), - Row(21474836472.2) :: - Row(92233720368547758071.3) :: Nil + Row(BigDecimal("21474836472.2")) :: + Row(BigDecimal("92233720368547758071.3")) :: Nil ) - // Widening to DoubleType + // Widening to Double checkAnswer( sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), Row(101.2) :: Row(21474836471.2) :: Nil @@ -438,13 +438,13 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - Row(92233720368547758071.2) + Row(BigDecimal("92233720368547758071.2")) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) + Row(new java.math.BigDecimal("92233720368547758071.2")) ) // String and Boolean conflict: resolve the type as string. @@ -503,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. 
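// The expected answers in these suites change from Double to BigDecimal because decimal values
// are now kept as exact decimals instead of being widened to Double. A minimal, Spark-free
// sketch of why that distinction matters for test expectations (names are illustrative only):
object DecimalVsDoubleSketch {
  def main(args: Array[String]): Unit = {
    // Binary floating point cannot represent 0.1 or 0.2 exactly, so the Double sum drifts.
    println(0.1 + 0.2)       // prints 0.30000000000000004
    assert(0.1 + 0.2 != 0.3)
    // BigDecimal arithmetic stays exact, so equality against the literal expectation holds.
    assert(BigDecimal("0.1") + BigDecimal("0.2") == BigDecimal("0.3"))
  }
}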
checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 13"), - Row(14.3) :: Row(92233720368547758071.2) :: Nil + Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index c037faf4cfd92..a95f70f2bba69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.apache.hadoop.fs.Path import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. @@ -123,6 +126,30 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { } } + test("Enabling/disabling merging partfiles when merging parquet schema") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + // delete summary files, so if we don't merge part-files, one column will not be included. + Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) + Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") { + testSchemaMerging(2) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") { + testSchemaMerging(3) + } + } + test("Enabling/disabling schema merging") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 5e189c3563ca8..cfb03ff485b7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -67,12 +67,12 @@ case class AllDataTypesScan( override def schema: StructType = userSpecifiedSchema - override def needConversion: Boolean = false + override def needConversion: Boolean = true override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => - InternalRow( - UTF8String.fromString(s"str_$i"), + Row( + s"str_$i", s"str_$i".getBytes(), i % 2 == 0, i.toByte, @@ -81,19 +81,19 @@ case class AllDataTypesScan( i.toLong, i.toFloat, i.toDouble, - Decimal(new java.math.BigDecimal(i)), - Decimal(new java.math.BigDecimal(i)), - DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)), - DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)), - UTF8String.fromString(s"varchar_$i"), + new java.math.BigDecimal(i), + new java.math.BigDecimal(i), + new Date(1970, 1, 1), + new Timestamp(20000 + i), + s"varchar_$i", Seq(i, i + 1), - Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), - Map(i -> UTF8String.fromString(i.toString)), - Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), - 
InternalRow(i, UTF8String.fromString(i.toString)), - InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), - InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) - }.asInstanceOf[RDD[Row]] + Seq(Map(s"str_$i" -> Row(i.toLong))), + Map(i -> i.toString), + Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), + Row(i, i.toString), + Row(Seq(s"str_$i", s"str_${i + 1}"), + Row(Seq(new Date(1970, 1, i + 1))))) + } } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala similarity index 97% rename from sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala rename to sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala index 1fe4fe9629c02..1a5ba20404c4e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala @@ -23,16 +23,16 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive /** - * Runs the test cases that are included in the hive distribution with sort merge join is true. + * Runs the test cases that are included in the hive distribution with hash joins. */ -class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { +class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) + TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) } override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) + TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) super.afterAll() } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b12b3838e615c..ec959cb2194b0 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -822,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udaf_covar_pop", "udaf_covar_samp", "udaf_histogram_numeric", - "udaf_number_format", "udf2", "udf5", "udf6", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index f467500259c91..5926ef9aa388b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -52,9 +52,8 @@ import scala.collection.JavaConversions._ * java.sql.Timestamp * Complex Types => * Map: scala.collection.immutable.Map - * List: scala.collection.immutable.Seq - * Struct: - * [[org.apache.spark.sql.catalyst.InternalRow]] + * List: [[org.apache.spark.sql.types.ArrayData]] + * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. 
* @@ -297,7 +296,10 @@ private[hive] trait HiveInspectors { }.toMap case li: StandardConstantListObjectInspector => // take the value from the list inspector object, rather than the input data - li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq + val values = li.getWritableConstantValue + .map(unwrap(_, li.getListElementObjectInspector)) + .toArray + new GenericArrayData(values) // if the value is null, we don't care about the object inspector type case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector @@ -339,7 +341,10 @@ private[hive] trait HiveInspectors { } case li: ListObjectInspector => Option(li.getList(data)) - .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq) + .map { l => + val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray + new GenericArrayData(values) + } .orNull case mi: MapObjectInspector => Option(mi.getMap(data)).map( @@ -391,7 +396,13 @@ private[hive] trait HiveInspectors { case loi: ListObjectInspector => val wrapper = wrapperFor(loi.getListElementObjectInspector) - (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null + (o: Any) => { + if (o != null) { + seqAsJavaList(o.asInstanceOf[ArrayData].toArray().map(wrapper)) + } else { + null + } + } case moi: MapObjectInspector => // The Predef.Map is scala.collection.immutable.Map. @@ -520,7 +531,7 @@ private[hive] trait HiveInspectors { case x: ListObjectInspector => val list = new java.util.ArrayList[Object] val tpe = dataType.asInstanceOf[ArrayType].elementType - a.asInstanceOf[Seq[_]].foreach { + a.asInstanceOf[ArrayData].toArray().foreach { v => list.add(wrap(v, x.getListElementObjectInspector, tpe)) } list @@ -634,7 +645,8 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt))) + value.asInstanceOf[ArrayData].toArray() + .foreach(v => list.add(wrap(v, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 3180c05445c9f..a8c9b4fa71b99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -274,9 +274,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging - // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to + // NOTE: Instead of passing Metastore schema directly to `ParquetRelation`, we have to // serialize the Metastore schema to JSON and pass it as a data source option because of the - // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. + // evil case insensitivity issue, which is reconciled within `ParquetRelation`. 
val parquetOptions = Map( ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) @@ -290,7 +290,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical@LogicalRelation(parquetRelation: ParquetRelation) => + case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. val useCached = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2f79b0aad045c..e6df64d2642bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -874,15 +874,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], String, Seq[(String, String)]) = clause match { + : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, "", Nil) + (rowFormat, None, Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, serdeClass, Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", @@ -891,9 +891,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => (name, value) } - (Nil, serdeClass, serdeProps) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, "", Nil) + case Nil => (Nil, None, Nil) } val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 205e622195f09..7e3342cc84c0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,15 +17,18 @@ package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} +import java.io._ import java.util.Properties +import javax.annotation.Nullable import scala.collection.JavaConversions._ +import scala.util.control.NonFatal import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.spark.{TaskContext, Logging} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -56,21 +59,53 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + def processIterator(inputIterator: Iterator[InternalRow]): 
Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) - // We need to start threads connected to the process pipeline: - // 1) The error msg generated by the script process would be hidden. - // 2) If the error msg is too big to chock up the buffer, the input logic would be hung + val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val errorStream = proc.getErrorStream - val reader = new BufferedReader(new InputStreamReader(inputStream)) - val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) + // In order to avoid deadlocks, we need to consume the error output of the child process. + // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. + val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + "Thread-ScriptTransformation-STDERR-Consumer").start() + + val outputProjection = new InterpretedProjection(input, child.output) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. + val writerThread = new ScriptTransformationWriterThread( + inputIterator, + outputProjection, + inputSerde, + inputSoi, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get() + ) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (outputSerde, outputSoi) = { + ioschema.initOutputSerDe(output).getOrElse((null, null)) + } - val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + val reader = new BufferedReader(new InputStreamReader(inputStream)) + val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var cacheRow: InternalRow = null var curLine: String = null var eof: Boolean = false @@ -79,12 +114,26 @@ case class ScriptTransformation( if (outputSerde == null) { if (curLine == null) { curLine = reader.readLine() - curLine != null + if (curLine == null) { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } else { + true + } } else { true } } else { - !eof + if (eof) { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } else { + true + } } } @@ -110,11 +159,11 @@ case class ScriptTransformation( } i += 1 }) - return mutableRow + mutableRow } catch { case e: EOFException => eof = true - return null + null } } @@ -127,13 +176,13 @@ case class ScriptTransformation( val prevLine = curLine curLine = reader.readLine() if (!ioschema.schemaLess) { - new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) - .asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + .map(CatalystTypeConverters.convertToCatalyst)) } else { - new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) - 
.asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + .map(CatalystTypeConverters.convertToCatalyst)) } } else { val ret = deserialize() @@ -146,49 +195,83 @@ case class ScriptTransformation( } } - val (inputSerde, inputSoi) = ioschema.initInputSerDe(input) - val dataOutputStream = new DataOutputStream(outputStream) - val outputProjection = new InterpretedProjection(input, child.output) + writerThread.start() - // TODO make the 2048 configurable? - val stderrBuffer = new CircularBuffer(2048) - // Consume the error stream from the pipeline, otherwise it will be blocked if - // the pipeline is full. - new RedirectThread(errorStream, // input stream from the pipeline - stderrBuffer, // output to a circular buffer - "Thread-ScriptTransformation-STDERR-Consumer").start() + outputIterator + } - // Put the write(output to the pipeline) into a single thread - // and keep the collector as remain in the main thread. - // otherwise it will causes deadlock if the data size greater than - // the pipeline / buffer capacity. - new Thread(new Runnable() { - override def run(): Unit = { - Utils.tryWithSafeFinally { - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize( - row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) - } - } - outputStream.close() - } { - if (proc.waitFor() != 0) { - logError(stderrBuffer.toString) // log the stderr circular buffer - } - } - } - }, "Thread-ScriptTransformation-Feed").start() + child.execute().mapPartitions { iter => + if (iter.hasNext) { + processIterator(iter) + } else { + // If the input iterator has no rows then do not launch the external script. + Iterator.empty + } + } + } +} - iterator +private class ScriptTransformationWriterThread( + iter: Iterator[InternalRow], + outputProjection: Projection, + @Nullable inputSerde: AbstractSerDe, + @Nullable inputSoi: ObjectInspector, + ioschema: HiveScriptIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext + ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { + + setDaemon(true) + + @volatile private var _exception: Throwable = null + + /** Contains the exception thrown while writing the parent iterator to the external process. 
*/ + def exception: Option[Throwable] = Option(_exception) + + override def run(): Unit = Utils.logUncaughtExceptions { + TaskContext.setTaskContext(taskContext) + + val dataOutputStream = new DataOutputStream(outputStream) + + // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so + // let's use a variable to record whether the `finally` block was hit due to an exception + var threwException: Boolean = true + try { + iter.map(outputProjection).foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + outputStream.write(data) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } + } + outputStream.close() + threwException = false + } catch { + case NonFatal(e) => + // An error occurred while writing input, so kill the child process. According to the + // Javadoc this call will not throw an exception: + _exception = e + proc.destroy() + throw e + } finally { + try { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer + } + } catch { + case NonFatal(exceptionFromFinallyBlock) => + if (!threwException) { + throw exceptionFromFinallyBlock + } else { + log.error("Exception in finally block", exceptionFromFinallyBlock) + } + } } } } @@ -200,33 +283,43 @@ private[hive] case class HiveScriptIOSchema ( inputRowFormat: Seq[(String, String)], outputRowFormat: Seq[(String, String)], - inputSerdeClass: String, - outputSerdeClass: String, + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { - val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n")) + private val defaultFormat = Map( + ("TOK_TABLEROWFORMATFIELD", "\t"), + ("TOK_TABLEROWFORMATLINES", "\n") + ) val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { - val (columns, columnTypes) = parseAttrs(input) - val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) - (serde, initInputSoi(serde, columns, columnTypes)) + def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = { + inputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(input) + val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) + val fieldObjectInspectors = columnTypes.map(toInspector) + val objectInspector = ObjectInspectorFactory + .getStandardStructObjectInspector(columns, fieldObjectInspectors) + .asInstanceOf[ObjectInspector] + (serde, objectInspector) + } } - def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = { - val (columns, columnTypes) = parseAttrs(output) - val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps) - (serde, initOutputputSoi(serde)) + def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + outputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(output) + val serde = initSerDe(serdeClass, columns, columnTypes, 
outputSerdeProps) + val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] + (serde, structObjectInspector) + } } - def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - + private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { val columns = attrs.map { case aref: AttributeReference => aref.name case e: NamedExpression => e.name @@ -242,52 +335,25 @@ case class HiveScriptIOSchema ( (columns, columnTypes) } - def initSerDe(serdeClassName: String, columns: Seq[String], - columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { + private def initSerDe( + serdeClassName: String, + columns: Seq[String], + columnTypes: Seq[DataType], + serdeProps: Seq[(String, String)]): AbstractSerDe = { - val serde: AbstractSerDe = if (serdeClassName != "") { - val trimed_class = serdeClassName.split("'")(1) - Utils.classForName(trimed_class) - .newInstance.asInstanceOf[AbstractSerDe] - } else { - null - } + val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe] - if (serde != null) { - val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") + val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - var propsMap = serdeProps.map(kv => { - (kv._1.split("'")(1), kv._2.split("'")(1)) - }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) - propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) + var propsMap = serdeProps.map(kv => { + (kv._1.split("'")(1), kv._2.split("'")(1)) + }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - val properties = new Properties() - properties.putAll(propsMap) - serde.initialize(null, properties) - } + val properties = new Properties() + properties.putAll(propsMap) + serde.initialize(null, properties) serde } - - def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType]) - : ObjectInspector = { - - if (inputSerde != null) { - val fieldObjectInspectors = columnTypes.map(toInspector(_)) - ObjectInspectorFactory - .getStandardStructObjectInspector(columns, fieldObjectInspectors) - .asInstanceOf[ObjectInspector] - } else { - null - } - } - - def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { - if (outputSerde != null) { - outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] - } else { - null - } - } } - diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 8732e9abf8d31..4a13022eddf60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -431,7 +431,7 @@ private[hive] case class HiveWindowFunction( // if pivotResult is true, we will get a Seq having the same size with the size // of the window frame. At here, we will return the result at the position of // index in the output buffer. 
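Stepping back from the diff for a moment: the rewritten ScriptTransformation above follows the classic pattern for talking to an external process without deadlocking, i.e. drain stderr on its own thread, feed stdin from a dedicated writer thread, and keep only the stdout reads on the calling thread. A stripped-down sketch of that shape using plain JDK classes (the thread names and the `cat` script are illustrative, not the patch's helper classes):

    import java.io.{BufferedReader, InputStreamReader, PrintWriter}

    val proc = new ProcessBuilder("/bin/bash", "-c", "cat").start()

    // Drain stderr on a daemon thread so a chatty script can never fill the pipe and block.
    new Thread("stderr-drainer") {
      setDaemon(true)
      override def run(): Unit =
        scala.io.Source.fromInputStream(proc.getErrorStream).foreach(_ => ())
    }.start()

    // Feed stdin from another daemon thread; only stdout is read on the current thread.
    new Thread("stdin-feeder") {
      setDaemon(true)
      override def run(): Unit = {
        val out = new PrintWriter(proc.getOutputStream)
        (1 to 100000).foreach(i => out.println(s"row_$i"))
        out.close()
      }
    }.start()

    val reader = new BufferedReader(new InputStreamReader(proc.getInputStream))
    Iterator.continually(reader.readLine()).takeWhile(_ != null).foreach(println)

The patch's version additionally bounds the retained stderr with a CircularBuffer and records writer-side exceptions so the reading side can rethrow them, which is what the new hasNext checks above are for.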
- outputBuffer.asInstanceOf[Seq[Any]].get(index) + outputBuffer.asInstanceOf[ArrayData].get(index) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 3662a4352f55d..7bbdef90cd6b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -56,6 +56,7 @@ object TestHive .set("spark.sql.test", "") .set("spark.sql.hive.metastore.barrierPrefixes", "org.apache.spark.sql.hive.execution.PairSerDe") + .set("spark.buffer.pageSize", "4m") // SPARK-8910 .set("spark.ui.enabled", "false"))) diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 deleted file mode 100644 index c6f275a0db131..0000000000000 --- a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 +++ /dev/null @@ -1 +0,0 @@ -0.0 NULL NULL NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 0330013f5325e..f719f2e06ab63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -217,7 +217,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { test("wrap / unwrap Array Type") { val dt = ArrayType(dataTypes(0)) - val d = row(0) :: row(0) :: Nil + val d = new GenericArrayData(Array(row(0), row(0))) checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f067ea0d4fc75..bc72b0172a467 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -172,7 +172,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = df.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + val shj = df.queryExecution.sparkPlan.collect { case j: SortMergeJoin => j } assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 4056dee777574..9b3ede43ee2d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{Row, QueryTest} case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { private lazy val 
ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala new file mode 100644 index 0000000000000..0875232aede3e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.types.StringType + +class ScriptTransformationSuite extends SparkPlanTest { + + override def sqlContext: SQLContext = TestHive + + private val noSerdeIOSchema = HiveScriptIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + inputSerdeClass = None, + outputSerdeClass = None, + inputSerdeProps = Seq.empty, + outputSerdeProps = Seq.empty, + schemaLess = false + ) + + private val serdeIOSchema = noSerdeIOSchema.copy( + inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), + outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) + ) + + test("cat without SerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = noSerdeIOSchema + )(TestHive), + rowsDf.collect()) + } + + test("cat with LazySimpleSerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = serdeIOSchema + )(TestHive), + rowsDf.collect()) + } + + test("script transformation should not swallow errors from upstream operators (no serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = 
"cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = noSerdeIOSchema + )(TestHive), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } + + test("script transformation should not swallow errors from upstream operators (with serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = serdeIOSchema + )(TestHive), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } +} + +private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode { + override protected def doExecute(): RDD[InternalRow] = { + child.execute().map { x => + assert(TaskContext.get() != null) // Make sure that TaskContext is defined. + Thread.sleep(1000) // This sleep gives the external process time to start. + throw new IllegalArgumentException("intentional exception") + } + } + override def output: Seq[Attribute] = child.output +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 65d4e933bf8e9..2780d5b6adbcf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator @@ -100,7 +101,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) - val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) + val fs = fsOption.getOrElse(path.getFileSystem(SparkHadoopUtil.get.conf)) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 92438f1b1fbf7..177e710ace54b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.FixedLengthBinaryInputFormat import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.serializer.SerializationDebugger @@ -110,7 +111,7 @@ class StreamingContext private[streaming] ( * Recreate a StreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(path, new Configuration) + def this(path: String) = this(path, SparkHadoopUtil.get.conf) /** * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. 
@@ -803,7 +804,7 @@ object StreamingContext extends Logging { def getActiveOrCreate( checkpointPath: String, creatingFunc: () => StreamingContext, - hadoopConf: Configuration = new Configuration(), + hadoopConf: Configuration = SparkHadoopUtil.get.conf, createOnError: Boolean = false ): StreamingContext = { ACTIVATION_LOCK.synchronized { @@ -828,7 +829,7 @@ object StreamingContext extends Logging { def getOrCreate( checkpointPath: String, creatingFunc: () => StreamingContext, - hadoopConf: Configuration = new Configuration(), + hadoopConf: Configuration = SparkHadoopUtil.get.conf, createOnError: Boolean = false ): StreamingContext = { val checkpointOption = CheckpointReader.read( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 959ac9c177f81..26383e420101e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -788,7 +788,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[F], - conf: Configuration = new Configuration) { + conf: Configuration = dstream.context.sparkContext.hadoopConfiguration) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 40deb6d7ea79a..35cc3ce5cf468 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -33,6 +33,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.spark.api.java.function.{Function0 => JFunction0} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ @@ -136,7 +137,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Recreate a JavaStreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(new StreamingContext(path, new Configuration)) + def this(path: String) = this(new StreamingContext(path, SparkHadoopUtil.get.conf)) /** * Re-creates a JavaStreamingContext from a checkpoint file. 
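The streaming changes in this stretch all replace `new Configuration()` defaults with `SparkHadoopUtil.get.conf`, so that checkpoint reads and writes see the Hadoop settings carried by the SparkConf (for example `spark.hadoop.*` entries) rather than a bare default Configuration. A short usage sketch against the `getOrCreate` signature shown above; the checkpoint path and app name are placeholders:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val checkpointDir = "hdfs:///tmp/streaming-checkpoint"  // placeholder

    def createContext(): StreamingContext = {
      val conf = new SparkConf().setAppName("checkpoint-example")
      val ssc = new StreamingContext(conf, Seconds(1))
      ssc.checkpoint(checkpointDir)
      ssc
    }

    // Picks up the new default, i.e. the Hadoop conf derived from the SparkConf.
    val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _)

    // An explicit Hadoop configuration can still be passed when something custom is needed.
    val customConf = new Configuration()
    val sscExplicit = StreamingContext.getOrCreate(checkpointDir, createContext _, customConf)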
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index d58c99a8ff321..a6c4cd220e42f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -21,7 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.rdd.RDDOperationScope -import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.scheduler.RateController +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils /** @@ -47,6 +49,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) /** This is an unique identifier for the input stream. */ val id = ssc.getNewInputStreamId() + // Keep track of the freshest rate for this stream using the rateEstimator + protected[streaming] val rateController: Option[RateController] = None + /** A human-readable name of this InputDStream */ private[streaming] def name: String = { // e.g. FlumePollingDStream -> "Flume polling stream" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index a50f0efc030ce..646a8c3530a62 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,10 +21,11 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId -import org.apache.spark.streaming._ +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.StreamInputInfo +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.streaming.util.WriteAheadLogUtils /** @@ -40,6 +41,17 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. + */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) } + } else { + None + } + } + /** * Gets the receiver object that will be sent to the worker nodes * to receive data. This method needs to defined by any specific implementation @@ -110,4 +122,14 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } + + /** + * A RateController that sends the new rate to receivers, via the receiver tracker. 
+ */ + private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = + ssc.scheduler.receiverTracker.sendRateUpdate(id, rate) + } } + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 4af9b6d3b56ab..58bdda7794bf2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -66,6 +66,12 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } eventLoop.start() + // attach rate controllers of input streams to receive batch completion updates + for { + inputDStream <- ssc.graph.getInputStreams + rateController <- inputDStream.rateController + } ssc.addStreamingListener(rateController) + listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala new file mode 100644 index 0000000000000..882ca0676b6ad --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.io.ObjectInputStream +import java.util.concurrent.atomic.AtomicLong + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * A StreamingListener that receives batch completion updates, and maintains + * an estimate of the speed at which this stream should ingest messages, + * given an estimate computation from a `RateEstimator` + */ +private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator) + extends StreamingListener with Serializable { + + init() + + protected def publish(rate: Long): Unit + + @transient + implicit private var executionContext: ExecutionContext = _ + + @transient + private var rateLimit: AtomicLong = _ + + /** + * An initialization method called both from the constructor and Serialization code. 
+ */ + private def init() { + executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update")) + rateLimit = new AtomicLong(-1L) + } + + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + init() + } + + /** + * Compute the new rate limit and publish it asynchronously. + */ + private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit = + Future[Unit] { + val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay) + newRate.foreach { s => + rateLimit.set(s.toLong) + publish(getLatestRate()) + } + } + + def getLatestRate(): Long = rateLimit.get() + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + val elements = batchCompleted.batchInfo.streamIdToInputInfo + + for { + processingEnd <- batchCompleted.batchInfo.processingEndTime; + workDelay <- batchCompleted.batchInfo.processingDelay; + waitDelay <- batchCompleted.batchInfo.schedulingDelay; + elems <- elements.get(streamUID).map(_.numRecords) + } computeAndPublish(processingEnd, elems, workDelay, waitDelay) + } +} + +object RateController { + def isBackPressureEnabled(conf: SparkConf): Boolean = + conf.getBoolean("spark.streaming.backpressure.enable", false) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6270137951b5a..e076fb5ea174b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -223,7 +223,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + synchronized { + if (isTrackerStarted) { + endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + } + } } } @@ -285,8 +289,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Update a receiver's maximum ingestion rate */ - def sendRateUpdate(streamUID: Int, newRate: Long): Unit = { - endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized { + if (isTrackerStarted) { + endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + } } /** Add new blocks for the given stream */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala new file mode 100644 index 0000000000000..a08685119e5d5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.SparkConf +import org.apache.spark.SparkException + +/** + * A component that estimates the rate at which an InputDStream should ingest + * elements, based on updates at every batch completion. + */ +private[streaming] trait RateEstimator extends Serializable { + + /** + * Computes the number of elements the stream attached to this `RateEstimator` + * should ingest per second, given an update on the size and completion + * times of the latest batch. + * + * @param time The timestamp of the current batch interval that just finished + * @param elements The number of elements that were processed in this batch + * @param processingDelay The time in ms it took for the job to complete + * @param schedulingDelay The time in ms that the job spent in the scheduling queue + */ + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] +} + +object RateEstimator { + + /** + * Return a new RateEstimator based on the value of `spark.streaming.backpressure.rateEstimator`. + * + * @return None if there is no configured estimator, otherwise an instance of RateEstimator + * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any + * known estimators. + */ + def create(conf: SparkConf): Option[RateEstimator] = + conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator => + throw new IllegalArgumentException(s"Unknown rate estimator: $estimator") + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index a34f23475804a..e0718f73aa13f 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1735,6 +1735,7 @@ public Integer call(String s) throws Exception { @SuppressWarnings("unchecked") @Test public void testContextGetOrCreate() throws InterruptedException { + ssc.stop(); final SparkConf conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 08faeaa58f419..255376807c957 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -81,39 +81,41 @@ class BasicOperationsSuite extends TestSuiteBase { test("repartition (more partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(5) - val ssc = setupStreams(input, operation, 2) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 5) - assert(second.size === 5) - assert(third.size === 5) - - assert(first.flatten.toSet.equals((1 to 100).toSet) ) - assert(second.flatten.toSet.equals((101 to
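The `RateEstimator` trait defined above is the extension point of the new backpressure machinery: each batch completion feeds the batch size and its delays into `compute`, and the owning `RateController` publishes the returned records-per-second value to the receivers (gated on `spark.streaming.backpressure.enable`). Purely as an illustration, a fixed-rate estimator in the spirit of the `ConstantEstimator` used by the test suites later in this patch could look like the sketch below; the class name is made up, and it has to live under the streaming package tree because the trait is `private[streaming]`:

    // The trait is private[streaming], so an implementation has to live in this package tree.
    package org.apache.spark.streaming.scheduler.rate

    // Always proposes the same ingestion rate, regardless of the observed batch delays.
    private[streaming] class FixedRateEstimator(recordsPerSecond: Double) extends RateEstimator {
      def compute(
          time: Long,
          elements: Long,
          processingDelay: Long,
          schedulingDelay: Long): Option[Double] = Some(recordsPerSecond)
    }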
200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 2)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 5) + assert(second.size === 5) + assert(third.size === 5) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("repartition (fewer partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(2) - val ssc = setupStreams(input, operation, 5) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 2) - assert(second.size === 2) - assert(third.size === 2) - - assert(first.flatten.toSet.equals((1 to 100).toSet)) - assert(second.flatten.toSet.equals( (101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 5)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 2) + assert(second.size === 2) + assert(third.size === 2) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("groupByKey") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index d308ac05a54fe..67c2d900940ab 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -30,8 +30,10 @@ import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} +import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver} import org.apache.spark.util.{Clock, ManualClock, Utils} /** @@ -391,6 +393,32 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation(input, operation, output, 7) } + test("recovery maintains rate controller") { + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir) + + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + } + SingletonTestRateReceiver.reset() + + val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2)) + output.register() + runStreams(ssc, 5, 5) + + SingletonTestRateReceiver.reset() + ssc = new StreamingContext(checkpointDir) + ssc.start() + val outputNew = advanceTimeWithRealDelay(ssc, 2) + + eventually(timeout(5.seconds)) { + assert(dstream.getCurrentRateLimit === Some(200)) + } + ssc.stop() + ssc = null + } + // This tests whether file input stream remembers what files were seen before // the master failure and uses them again to process a large window 
operation. // It also tests whether batches, whose processing was incomplete due to the diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index b74d67c63a788..ec2852d9a0206 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -325,27 +325,31 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("test track the number of input stream") { - val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => - class TestInputDStream extends InputDStream[String](ssc) { - def start() { } - def stop() { } - def compute(validTime: Time): Option[RDD[String]] = None - } + class TestInputDStream extends InputDStream[String](ssc) { + def start() {} - class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { - def getReceiver: Receiver[String] = null - } + def stop() {} + + def compute(validTime: Time): Option[RDD[String]] = None + } + + class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { + def getReceiver: Receiver[String] = null + } - // Register input streams - val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) - val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + // Register input streams + val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) + val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) - assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length) - assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) - assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) - assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) - assert(receiverInputStreams.map(_.id) === Array(0, 1)) + assert(ssc.graph.getInputStreams().length == + receiverInputStreams.length + inputStreams.length) + assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) + assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) + assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) + assert(receiverInputStreams.map(_.id) === Array(0, 1)) + } } def testFileStream(newFilesOnly: Boolean) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 6e9d4431090a2..0e64b57e0ffd8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -244,7 +244,13 @@ object MasterFailureTest extends Logging { } catch { case e: Exception => logError("Error running streaming context", e) } - if (killingThread.isAlive) killingThread.interrupt() + if (killingThread.isAlive) { + killingThread.interrupt() + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is + // stopped before running the next test. Otherwise, it's possible that we set SparkEnv.env + // to null after the next test creates the new SparkContext and fail the test. 
+ killingThread.join() + } ssc.stop() logInfo("Has been killed = " + killed) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4bba9691f8aa5..84a5fbb3d95eb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -120,7 +120,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) - val ssc = new StreamingContext(myConf, batchDuration) + ssc = new StreamingContext(myConf, batchDuration) assert(ssc.checkpointDir != null) } @@ -369,16 +369,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + var t: Thread = null // test whether wait exits if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() ssc.awaitTermination() } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. + t.join() } test("awaitTermination after stop") { @@ -430,16 +436,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.awaitTerminationOrTimeout(500) === false) } + var t: Thread = null // test whether awaitTerminationOrTimeout() return true if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() assert(ssc.awaitTerminationOrTimeout(10000) === true) } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. 
+ t.join() } test("getOrCreate") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 4bc1dd4a30fc4..d840c349bbbc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -36,13 +36,22 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // To make sure that the processing start and end times in collected // information are different for successive batches override def batchDuration: Duration = Milliseconds(100) override def actuallyWait: Boolean = true test("batch info reporting") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) @@ -107,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } test("receiver info reporting") { - val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) inputStream.foreachRDD(_.count) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala new file mode 100644 index 0000000000000..921da773f6c11 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator + +class RateControllerSuite extends TestSuiteBase { + + override def useManualClock: Boolean = false + + test("rate controller publishes updates") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateLimitInputDStream(ssc) + dstream.register() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.publishCalls > 0) + } + } + } + + test("publish rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + } + dstream.register() + SingletonTestRateReceiver.reset() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.getCurrentRateLimit === Some(200)) + } + } + } + + test("multiple publish rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val rates = Seq(100L, 200L, 300L) + + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*))) + } + SingletonTestRateReceiver.reset() + dstream.register() + + val observedRates = mutable.HashSet.empty[Long] + ssc.start() + + eventually(timeout(20.seconds)) { + dstream.getCurrentRateLimit.foreach(observedRates += _) + // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver + observedRates should contain theSameElementsAs (rates :+ Long.MaxValue) + } + } + } +} + +private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator { + private var idx: Int = 0 + + private def nextRate(): Double = { + val rate = rates(idx) + idx = (idx + 1) % rates.size + rate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(nextRate()) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala index 93f920fdc71f1..0418d776ecc9a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -64,7 +64,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { test("scheduleReceivers: " + "schedule receivers evenly when there are more receivers than executors") { - val receivers = (0 until 6).map(new DummyReceiver(_)) + val receivers = (0 until 6).map(new RateTestReceiver(_)) val executors = (10000 until 10003).map(port => s"localhost:${port}") val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) val numReceiversOnExecutor = mutable.HashMap[String, Int]() @@ -79,7 +79,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { test("scheduleReceivers: " + "schedule receivers evenly when there are more executors than receivers") { - val receivers = (0 until 
3).map(new DummyReceiver(_)) + val receivers = (0 until 3).map(new RateTestReceiver(_)) val executors = (10000 until 10006).map(port => s"localhost:${port}") val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) val numReceiversOnExecutor = mutable.HashMap[String, Int]() @@ -94,8 +94,8 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { } test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") { - val receivers = (0 until 3).map(new DummyReceiver(_)) ++ - (3 until 6).map(new DummyReceiver(_, Some("localhost"))) + val receivers = (0 until 3).map(new RateTestReceiver(_)) ++ + (3 until 6).map(new RateTestReceiver(_, Some("localhost"))) val executors = (10000 until 10003).map(port => s"localhost:${port}") ++ (10003 until 10006).map(port => s"localhost2:${port}") val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) @@ -121,7 +121,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { } test("scheduleReceivers: return empty scheduled executors if no executors") { - val receivers = (0 until 3).map(new DummyReceiver(_)) + val receivers = (0 until 3).map(new RateTestReceiver(_)) val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) scheduledExecutors.foreach { case (receiverId, executors) => assert(executors.isEmpty) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index e2159bd4f225d..afad5f16dbc71 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -29,69 +29,100 @@ import org.apache.spark.storage.StorageLevel /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") - val ssc = new StreamingContext(sparkConf, Milliseconds(100)) test("Receiver tracker - propagates rate limit") { - object ReceiverStartedWaiter extends StreamingListener { - @volatile - var started = false - - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { - started = true + withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc => + object ReceiverStartedWaiter extends StreamingListener { + @volatile + var started = false + + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + started = true + } } - } - - ssc.addStreamingListener(ReceiverStartedWaiter) - ssc.scheduler.listenerBus.start(ssc.sc) - - val newRateLimit = 100L - val inputDStream = new RateLimitInputDStream(ssc) - val tracker = new ReceiverTracker(ssc) - tracker.start() - // we wait until the Receiver has registered with the tracker, - // otherwise our rate update is lost - eventually(timeout(5 seconds)) { - assert(ReceiverStartedWaiter.started) - } - tracker.sendRateUpdate(inputDStream.id, newRateLimit) - // this is an async message, we need to wait a bit for it to be processed - eventually(timeout(3 seconds)) { - assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + ssc.addStreamingListener(ReceiverStartedWaiter) + ssc.scheduler.listenerBus.start(ssc.sc) + SingletonTestRateReceiver.reset() + + val newRateLimit = 100L + val inputDStream = new RateLimitInputDStream(ssc) + val tracker = new 
ReceiverTracker(ssc) + tracker.start() + try { + // we wait until the Receiver has registered with the tracker, + // otherwise our rate update is lost + eventually(timeout(5 seconds)) { + assert(ReceiverStartedWaiter.started) + } + tracker.sendRateUpdate(inputDStream.id, newRateLimit) + // this is an async message, we need to wait a bit for it to be processed + eventually(timeout(3 seconds)) { + assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + } + } finally { + tracker.stop(false) + } } } } -/** An input DStream with a hard-coded receiver that gives access to internals for testing. */ -private class RateLimitInputDStream(@transient ssc_ : StreamingContext) +/** + * An input DStream with a hard-coded receiver that gives access to internals for testing. + * + * @note Make sure to call {{{SingletonTestRateReceiver.reset()}}} before using this in a test, + * otherwise you may get a {{{NotSerializableException}}} when trying to serialize + * the receiver. + * @see [[SingletonTestRateReceiver]]. + */ +private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext) extends ReceiverInputDStream[Int](ssc_) { - override def getReceiver(): DummyReceiver = SingletonDummyReceiver + override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver def getCurrentRateLimit: Option[Long] = { invokeExecutorMethod.getCurrentRateLimit } + @volatile + var publishCalls = 0 + + override val rateController: Option[RateController] = { + Some(new RateController(id, new ConstantEstimator(100.0)) { + override def publish(rate: Long): Unit = { + publishCalls += 1 + } + }) + } + private def invokeExecutorMethod: ReceiverSupervisor = { val c = classOf[Receiver[_]] val ex = c.getDeclaredMethod("executor") ex.setAccessible(true) - ex.invoke(SingletonDummyReceiver).asInstanceOf[ReceiverSupervisor] + ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor] } } /** - * A Receiver as an object so we can read its rate limit. + * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when + * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being + * serialized when receivers are installed on executors. * * @note It's necessary to be a top-level object, or else serialization would create another * one on the executor side and we won't be able to read its rate limit. */ -private object SingletonDummyReceiver extends DummyReceiver(0) +private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) { + + /** Reset the object to be usable in another test.
*/ + def reset(): Unit = { + executor_ = null + } +} /** * Dummy receiver implementation */ -private class DummyReceiver(receiverId: Int, host: Option[String] = None) +private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { setReceiverId(receiverId) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 0891309f956d2..995f1197ccdfd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -22,15 +22,24 @@ import java.util.Properties import org.scalatest.Matchers import org.apache.spark.scheduler.SparkListenerJobStart +import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + private def createJobStart( batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = { val properties = new Properties() @@ -46,7 +55,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + "onReceiverStarted, onReceiverError, onReceiverStopped") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val streamIdToInputInfo = Map( @@ -141,7 +150,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("Remove the old completed batches when exceeding the limit") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -158,7 +167,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("out-of-order onJobStart and onBatchXXX") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -209,7 +218,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("detect memory leak") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d0bde69cc1068..198e0684f32f8 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -74,12 +74,6 @@ public final class BytesToBytesMap { */ private long pageCursor = 0; - /** - * The size of the data pages that hold key and value data. 
Map entries cannot span multiple - * pages, so this limits the maximum entry size. - */ - private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes - /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since @@ -117,6 +111,12 @@ public final class BytesToBytesMap { private final double loadFactor; + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private final long pageSizeBytes; + /** * Number of keys defined in the map. */ @@ -153,10 +153,12 @@ public BytesToBytesMap( TaskMemoryManager memoryManager, int initialCapacity, double loadFactor, + long pageSizeBytes, boolean enablePerfMetrics) { this.memoryManager = memoryManager; this.loadFactor = loadFactor; this.loc = new Location(); + this.pageSizeBytes = pageSizeBytes; this.enablePerfMetrics = enablePerfMetrics; if (initialCapacity <= 0) { throw new IllegalArgumentException("Initial capacity must be greater than 0"); @@ -165,18 +167,26 @@ public BytesToBytesMap( throw new IllegalArgumentException( "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY); } + if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) { + throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " + + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); + } allocate(initialCapacity); } - public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) { - this(memoryManager, initialCapacity, 0.70, false); + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + long pageSizeBytes) { + this(memoryManager, initialCapacity, 0.70, pageSizeBytes, false); } public BytesToBytesMap( TaskMemoryManager memoryManager, int initialCapacity, + long pageSizeBytes, boolean enablePerfMetrics) { - this(memoryManager, initialCapacity, 0.70, enablePerfMetrics); + this(memoryManager, initialCapacity, 0.70, pageSizeBytes, enablePerfMetrics); } /** @@ -443,20 +453,20 @@ public void putNewKey( // must be stored in the same memory page. // (8 byte key length) (key) (8 byte value length) (value) final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; - assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. + assert (requiredSize <= pageSizeBytes - 8); // Reserve 8 bytes for the end-of-page marker. size++; bitset.set(pos); // If there's not enough space in the current page, allocate a new page (8 bytes are reserved // for the end-of-page marker). - if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) { + if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) { if (currentDataPage != null) { // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } - MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); + MemoryBlock newPage = memoryManager.allocatePage(pageSizeBytes); dataPages.add(newPage); pageCursor = 0; currentDataPage = newPage; @@ -538,10 +548,11 @@ public void free() { /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. 
*/ public long getTotalMemoryConsumption() { - return ( - dataPages.size() * PAGE_SIZE_BYTES + - bitset.memoryBlock().size() + - longArray.memoryBlock().size()); + long totalDataPagesSize = 0L; + for (MemoryBlock dataPage : dataPages) { + totalDataPagesSize += dataPage.size(); + } + return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size(); } /** diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 10881969dbc78..dd70df3b1f791 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -58,8 +58,13 @@ public class TaskMemoryManager { /** The number of entries in the page table. */ private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; - /** Maximum supported data page size */ - private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS); + /** + * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is + * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page + * size is limited by the maximum amount of data that can be stored in a long[] array, which is + * (2^31 - 1) * 8 bytes (or about 16 gigabytes). Therefore, we cap this at 16 gigabytes. + */ + public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L; /** Bit mask for the lower 51 bits of a long. */ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; @@ -110,9 +115,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { * intended for allocating large blocks of memory that will be shared between operators. */ public MemoryBlock allocatePage(long size) { - if (size > MAXIMUM_PAGE_SIZE) { + if (size > MAXIMUM_PAGE_SIZE_BYTES) { throw new IllegalArgumentException( - "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes"); + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } final int pageNumber; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java similarity index 87% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java rename to unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 71b1a85a818ea..92a5e4f86f234 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -24,7 +24,7 @@ /** * The internal representation of interval type.
*/ -public final class Interval implements Serializable { +public final class CalendarInterval implements Serializable { public static final long MICROS_PER_MILLI = 1000L; public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000; public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60; @@ -58,7 +58,7 @@ private static long toLong(String s) { } } - public static Interval fromString(String s) { + public static CalendarInterval fromString(String s) { if (s == null) { return null; } @@ -75,40 +75,40 @@ public static Interval fromString(String s) { microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; microseconds += toLong(m.group(9)); - return new Interval((int) months, microseconds); + return new CalendarInterval((int) months, microseconds); } } public final int months; public final long microseconds; - public Interval(int months, long microseconds) { + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; } - public Interval add(Interval that) { + public CalendarInterval add(CalendarInterval that) { int months = this.months + that.months; long microseconds = this.microseconds + that.microseconds; - return new Interval(months, microseconds); + return new CalendarInterval(months, microseconds); } - public Interval subtract(Interval that) { + public CalendarInterval subtract(CalendarInterval that) { int months = this.months - that.months; long microseconds = this.microseconds - that.microseconds; - return new Interval(months, microseconds); + return new CalendarInterval(months, microseconds); } - public Interval negate() { - return new Interval(-this.months, -this.microseconds); + public CalendarInterval negate() { + return new CalendarInterval(-this.months, -this.microseconds); } @Override public boolean equals(Object other) { if (this == other) return true; - if (other == null || !(other instanceof Interval)) return false; + if (other == null || !(other instanceof CalendarInterval)) return false; - Interval o = (Interval) other; + CalendarInterval o = (CalendarInterval) other; return this.months == o.months && this.microseconds == o.microseconds; } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 85381cf0ef425..c38953f65d7d7 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -65,6 +65,19 @@ public static UTF8String fromBytes(byte[] bytes) { } } + /** + * Creates an UTF8String from byte array, which should be encoded in UTF-8. + * + * Note: `bytes` will be hold by returned UTF8String. + */ + public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { + if (bytes != null) { + return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + } else { + return null; + } + } + /** * Creates an UTF8String from String. */ @@ -89,10 +102,10 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int size) { + protected UTF8String(Object base, long offset, int numBytes) { this.base = base; this.offset = offset; - this.numBytes = size; + this.numBytes = numBytes; } /** @@ -137,6 +150,32 @@ public int numChars() { return len; } + /** + * Returns a 64-bit integer that can be used as the prefix used in sorting. 
+ */ + public long getPrefix() { + // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string. + // If size is 0, just return 0. + // If size is between 0 and 4 (inclusive), assume data is 4-byte aligned under the hood and + // use a getInt to fetch the prefix. + // If size is greater than 4, assume we have at least 8 bytes of data to fetch. + // After getting the data, we use a mask to mask out data that is not part of the string. + long p; + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + p = p & ((1L << numBytes * 8) - 1); + } else if (numBytes > 0) { + p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + p = p & ((1L << numBytes * 8) - 1); + } else { + p = 0; + } + p = java.lang.Long.reverseBytes(p); + return p; + } + /** * Returns the underline bytes, will be a copy of it if it's part of another array. */ @@ -300,13 +339,13 @@ public UTF8String trimRight() { } public UTF8String reverse() { - byte[] bytes = getBytes(); - byte[] result = new byte[bytes.length]; + byte[] result = new byte[this.numBytes]; int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - System.arraycopy(bytes, i, result, result.length - i - len, len); + copyMemory(this.base, this.offset + i, result, + BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -316,11 +355,11 @@ public UTF8String reverse() { public UTF8String repeat(int times) { if (times <=0) { - return fromBytes(new byte[0]); + return EMPTY_UTF8; } byte[] newBytes = new byte[numBytes * times]; - System.arraycopy(getBytes(), 0, newBytes, 0, numBytes); + copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -385,16 +424,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - System.arraycopy(getBytes(), 0, data, 0, this.numBytes); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; - byte[] padBytes = pad.getBytes(); while (idx < count) { - System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++idx; offset += pad.numBytes; } - System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -421,15 +459,14 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; - byte[] padBytes = pad.getBytes(); while (idx < count) { - System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++idx; offset += pad.numBytes; } - System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - System.arraycopy(getBytes(), 0, data, offset, numBytes()); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -454,9 +491,9 @@ public static UTF8String concat(UTF8String... 
inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - PlatformDependent.copyMemory( + copyMemory( inputs[i].base, inputs[i].offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, len); offset += len; } @@ -494,7 +531,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - PlatformDependent.copyMemory( + copyMemory( inputs[i].base, inputs[i].offset, result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, len); @@ -503,7 +540,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. if (j < numInputs) { - PlatformDependent.copyMemory( + copyMemory( separator.base, separator.offset, result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, separator.numBytes); diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index dae47e4bab0cb..0be94ad371255 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -43,6 +43,7 @@ public abstract class AbstractBytesToBytesMapSuite { private TaskMemoryManager memoryManager; private TaskMemoryManager sizeLimitedMemoryManager; + private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes @Before public void setup() { @@ -110,7 +111,7 @@ private static boolean arrayEquals( @Test public void emptyMap() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES); try { Assert.assertEquals(0, map.size()); final int keyLengthInWords = 10; @@ -125,7 +126,7 @@ public void emptyMap() { @Test public void setAndRetrieveAKey() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES); final int recordLengthWords = 10; final int recordLengthBytes = recordLengthWords * 8; final byte[] keyData = getRandomByteArray(recordLengthWords); @@ -177,7 +178,7 @@ public void setAndRetrieveAKey() { @Test public void iteratorTest() throws Exception { final int size = 4096; - BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); + BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2, PAGE_SIZE_BYTES); try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; @@ -235,7 +236,7 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final int NUM_ENTRIES = 1000 * 1000; final int KEY_LENGTH = 16; final int VALUE_LENGTH = 40; - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte // pages won't be evenly-divisible by records of this size, which will cause us to waste some // space at the end of the page. This is necessary in order for us to take the end-of-record @@ -304,7 +305,7 @@ public void randomizedStressTest() { // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. 
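(Editorial note, not part of the patch.) The test hunks around this point show the practical effect of the BytesToBytesMap change: callers now pass an explicit page size instead of relying on the removed hard-coded 64 MB constant. A minimal Scala sketch of the new construction path, assuming the on-heap allocator from the unsafe module and reusing the 64 MB value the test suite defines as PAGE_SIZE_BYTES:

import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}

val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
val pageSizeBytes = 1L << 26 // 64 MB; must not exceed TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES
val map = new BytesToBytesMap(taskMemoryManager, /* initialCapacity */ 64, pageSizeBytes)
try {
  // insert and look up records here; each record must fit within a single page
} finally {
  map.free() // releases the data pages, bit set, and long array
}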
final Map expected = new HashMap(); - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size, PAGE_SIZE_BYTES); try { // Fill the map to 90% full so that we can trigger probing @@ -353,14 +354,15 @@ public void randomizedStressTest() { @Test public void initialCapacityBoundsChecking() { try { - new BytesToBytesMap(sizeLimitedMemoryManager, 0); + new BytesToBytesMap(sizeLimitedMemoryManager, 0, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception } try { - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1); + new BytesToBytesMap( + sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception @@ -368,15 +370,15 @@ public void initialCapacityBoundsChecking() { // Can allocate _at_ the max capacity BytesToBytesMap map = - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY); + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY, PAGE_SIZE_BYTES); map.free(); } @Test public void resizingLargeMap() { // As long as a map's capacity is below the max, we should be able to resize up to the max - BytesToBytesMap map = - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64); + BytesToBytesMap map = new BytesToBytesMap( + sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64, PAGE_SIZE_BYTES); map.growAndRehash(); map.free(); } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java index d29517cda66a3..e6733a7aae6f5 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -20,16 +20,16 @@ import org.junit.Test; import static junit.framework.Assert.*; -import static org.apache.spark.unsafe.types.Interval.*; +import static org.apache.spark.unsafe.types.CalendarInterval.*; public class IntervalSuite { @Test public void equalsTest() { - Interval i1 = new Interval(3, 123); - Interval i2 = new Interval(3, 321); - Interval i3 = new Interval(1, 123); - Interval i4 = new Interval(3, 123); + CalendarInterval i1 = new CalendarInterval(3, 123); + CalendarInterval i2 = new CalendarInterval(3, 321); + CalendarInterval i3 = new CalendarInterval(1, 123); + CalendarInterval i4 = new CalendarInterval(3, 123); assertNotSame(i1, i2); assertNotSame(i1, i3); @@ -39,21 +39,21 @@ public void equalsTest() { @Test public void toStringTest() { - Interval i; + CalendarInterval i; - i = new Interval(34, 0); + i = new CalendarInterval(34, 0); assertEquals(i.toString(), "interval 2 years 10 months"); - i = new Interval(-34, 0); + i = new CalendarInterval(-34, 0); assertEquals(i.toString(), "interval -2 years -10 months"); - i = new Interval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + i = new CalendarInterval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); - i = new Interval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); + i = new CalendarInterval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); - i = new Interval(34, 3 * MICROS_PER_WEEK 
+ 13 * MICROS_PER_HOUR + 123); + i = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); } @@ -72,33 +72,33 @@ public void fromStringTest() { String input; input = "interval -5 years 23 month"; - Interval result = new Interval(-5 * 12 + 23, 0); - assertEquals(Interval.fromString(input), result); + CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); + assertEquals(CalendarInterval.fromString(input), result); input = "interval -5 years 23 month "; - assertEquals(Interval.fromString(input), result); + assertEquals(CalendarInterval.fromString(input), result); input = " interval -5 years 23 month "; - assertEquals(Interval.fromString(input), result); + assertEquals(CalendarInterval.fromString(input), result); // Error cases input = "interval 3month 1 hour"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = "interval 3 moth 1 hour"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = "interval"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = "int"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = ""; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = null; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); } @Test @@ -106,18 +106,18 @@ public void addTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - Interval interval = Interval.fromString(input); - Interval interval2 = Interval.fromString(input2); + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR)); + assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); input = "interval -10 month -81 hour"; input2 = "interval 75 month 200 hour"; - interval = Interval.fromString(input); - interval2 = Interval.fromString(input2); + interval = CalendarInterval.fromString(input); + interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR)); + assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); } @Test @@ -125,25 +125,25 @@ public void subtractTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - Interval interval = Interval.fromString(input); - Interval interval2 = Interval.fromString(input2); + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR)); + assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR)); input = "interval -10 month -81 hour"; input2 = "interval 75 month 200 hour"; - interval = Interval.fromString(input); - interval2 = Interval.fromString(input2); + interval = CalendarInterval.fromString(input); + interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR)); + 
assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR)); } private void testSingleUnit(String unit, int number, int months, long microseconds) { String input1 = "interval " + number + " " + unit; String input2 = "interval " + number + " " + unit + "s"; - Interval result = new Interval(months, microseconds); - assertEquals(Interval.fromString(input1), result); - assertEquals(Interval.fromString(input2), result); + CalendarInterval result = new CalendarInterval(months, microseconds); + assertEquals(CalendarInterval.fromString(input1), result); + assertEquals(CalendarInterval.fromString(input2), result); } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index e2a5628ff4d93..f2cc19ca6b172 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -63,8 +63,27 @@ public void emptyStringTest() { assertEquals(0, EMPTY_UTF8.numBytes()); } + @Test + public void prefix() { + assertTrue(fromString("a").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue(fromString("ab").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue( + fromString("abbbbbbbbbbbasdf").getPrefix() - fromString("bbbbbbbbbbbbasdf").getPrefix() < 0); + assertTrue(fromString("").getPrefix() - fromString("a").getPrefix() < 0); + assertTrue(fromString("你好").getPrefix() - fromString("世界").getPrefix() > 0); + + byte[] buf1 = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + byte[] buf2 = {1, 2, 3}; + UTF8String str1 = UTF8String.fromBytes(buf1, 0, 3); + UTF8String str2 = UTF8String.fromBytes(buf1, 0, 8); + UTF8String str3 = UTF8String.fromBytes(buf2); + assertTrue(str1.getPrefix() - str2.getPrefix() < 0); + assertEquals(str1.getPrefix(), str3.getPrefix()); + } + @Test public void compareTo() { + assertTrue(fromString("").compareTo(fromString("a")) < 0); assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0); assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 0); diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 44acc7374d024..1d67b3ebb51b7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -229,7 +229,11 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM( + _rpcEnv: RpcEnv, + driverRef: RpcEndpointRef, + uiAddress: String, + securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = client.getAttemptId().getApplicationId().toString() @@ -246,6 +250,7 @@ private[spark] class ApplicationMaster( RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) allocator = client.register(driverUrl, + driverRef, yarnConf, _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), @@ -262,17 +267,20 @@ private[spark] class ApplicationMaster( * * In cluster mode, the AM and the driver belong to same process * so the AMEndpoint need not monitor lifecycle of the driver. 
+ * + * @return A reference to the driver's RPC endpoint. */ private def runAMEndpoint( host: String, port: String, - isClusterMode: Boolean): Unit = { + isClusterMode: Boolean): RpcEndpointRef = { val driverEndpoint = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, RpcAddress(host, port.toInt), YarnSchedulerBackend.ENDPOINT_NAME) amEndpoint = rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode)) + driverEndpoint } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -290,11 +298,11 @@ private[spark] class ApplicationMaster( "Timed out waiting for SparkContext.") } else { rpcEnv = sc.env.rpcEnv - runAMEndpoint( + val driverRef = runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -302,9 +310,9 @@ private[spark] class ApplicationMaster( private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { val port = sparkConf.getInt("spark.yarn.am.port", 0) rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) - waitForSparkDriver() + val driverRef = waitForSparkDriver() addAmIpFilter() - registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -428,7 +436,7 @@ private[spark] class ApplicationMaster( } } - private def waitForSparkDriver(): Unit = { + private def waitForSparkDriver(): RpcEndpointRef = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false val hostport = args.userArgs(0) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index bc28ce5eeae72..4ac3397f1ad28 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -767,7 +767,7 @@ private[spark] class Client( amContainer.setCommands(printableCommands) logDebug("===============================================================================") - logDebug("Yarn AM launch context:") + logDebug("YARN AM launch context:") logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}") logDebug(" env:") launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 78e27fb7f3337..52580deb372c2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -86,10 +86,17 @@ class ExecutorRunnable( val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, appId, localResources) - logInfo(s"Setting up executor with environment: $env") - logInfo("Setting up executor with commands: " + commands) - ctx.setCommands(commands) + logInfo(s""" + |=============================================================================== + |YARN executor launch context: + | env: + |${env.map { case (k, v) => s" $k -> $v\n" }.mkString} + | command: + | ${commands.mkString(" ")} + 
|=============================================================================== + """.stripMargin) + ctx.setCommands(commands) ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) // If external shuffle service is enabled, register with the Yarn shuffle service already diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 6c103394af098..59caa787b6e20 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -36,6 +36,9 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -52,6 +55,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ */ private[yarn] class YarnAllocator( driverUrl: String, + driverRef: RpcEndpointRef, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -88,6 +92,9 @@ private[yarn] class YarnAllocator( // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] + private var numUnexpectedContainerRelease = 0L + private val containerIdToExecutorId = new HashMap[ContainerId, String] + // Executor memory in MB. protected val executorMemory = args.executorMemory // Additional memory overhead. @@ -184,6 +191,7 @@ private[yarn] class YarnAllocator( def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { val container = executorIdToContainer.remove(executorId).get + containerIdToExecutorId.remove(container.getId) internalReleaseContainer(container) numExecutorsRunning -= 1 } else { @@ -383,6 +391,7 @@ private[yarn] class YarnAllocator( logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) executorIdToContainer(executorId) = container + containerIdToExecutorId(container.getId) = executorId val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]) @@ -413,12 +422,8 @@ private[yarn] class YarnAllocator( private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId - - if (releasedContainers.contains(containerId)) { - // Already marked the container for release, so remove it from - // `releasedContainers`. - releasedContainers.remove(containerId) - } else { + val alreadyReleased = releasedContainers.remove(containerId) + if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. 
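(Editorial note, not part of the patch.) The YarnAllocator hunks that follow thread a driver RpcEndpointRef into the allocator so that a container which completes without having been released can be reported to the scheduler backend. A hedged sketch of that notification path; the executor and container ids are assumed to come from the new containerIdToExecutorId bookkeeping:

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor

// Called when YARN reports a completed container that the allocator never asked to release.
def reportLostExecutor(driverRef: RpcEndpointRef, executorId: String, containerId: String): Unit = {
  driverRef.send(RemoveExecutor(executorId,
    s"Yarn deallocated the executor $executorId (container $containerId)"))
}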
numExecutorsRunning -= 1 @@ -460,6 +465,18 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.remove(containerId) } + + containerIdToExecutorId.remove(containerId).foreach { eid => + executorIdToContainer.remove(eid) + + if (!alreadyReleased) { + // The executor could have gone away (like no route to host, node failure, etc) + // Notify backend about the failure of the executor + numUnexpectedContainerRelease += 1 + driverRef.send(RemoveExecutor(eid, + s"Yarn deallocated the executor $eid (container $containerId)")) + } + } } } @@ -467,6 +484,9 @@ private[yarn] class YarnAllocator( releasedContainers.add(container.getId()) amClient.releaseAssignedContainer(container.getId()) } + + private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease + } private object YarnAllocator { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 7f533ee55e8bb..4999f9c06210a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.Utils @@ -56,6 +57,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg */ def register( driverUrl: String, + driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -73,7 +75,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args, + securityMgr) } /** diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 37a789fcd375b..58318bf9bcc08 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -27,10 +27,14 @@ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.mockito.Mockito._ + import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo class MockResolver extends DNSToSwitchMapping { @@ -90,6 +94,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter "--class", "SomeClass") new YarnAllocator( "not used", + mock(classOf[RpcEndpointRef]), conf, sparkConf, rmClient, @@ -230,6 +235,30 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumPendingAllocate should be (1) } + test("lost executor removed from backend") { + val handler = createAllocator(4) 
+ handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map()) + + val statuses = Seq(container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) + } + handler.updateResourceRequests() + handler.processCompletedContainers(statuses.toSeq) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (2) + handler.getNumExecutorsFailed should be (2) + handler.getNumUnexpectedContainerRelease should be (2) + } + test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " +
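(Editorial note, not part of the patch.) Circling back to the UTF8String.getPrefix() method introduced earlier in this diff: it packs up to the first eight bytes of the string into a big-endian long so sorters can compare prefixes with a single integer comparison. A small Scala sketch of the property the new prefix test relies on, using ASCII inputs so that a plain signed long comparison matches byte order:

import org.apache.spark.unsafe.types.UTF8String

val a = UTF8String.fromString("abcdefghij")
val b = UTF8String.fromString("abcdefghiz") // differs only after the 8th byte
val c = UTF8String.fromString("abd")

assert(a.getPrefix() == b.getPrefix()) // identical first 8 bytes => identical prefix
assert(a.getPrefix() < c.getPrefix())  // agrees with a.compareTo(c) < 0 for these ASCII strings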