diff --git a/R/README.md b/R/README.md index a6970e39b55f3..d7d65b4f0eca5 100644 --- a/R/README.md +++ b/R/README.md @@ -52,7 +52,7 @@ The SparkR documentation (Rd files and HTML files) are not a part of the source SparkR comes with several sample programs in the `examples/src/main/r` directory. To run one of them, use `./bin/sparkR `. For example: - ./bin/sparkR examples/src/main/r/pi.R local[2] + ./bin/sparkR examples/src/main/r/dataframe.R You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): @@ -63,5 +63,5 @@ You can also run the unit-tests for SparkR by running (you need to install the [ The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run ``` export YARN_CONF_DIR=/etc/hadoop/conf -./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +./bin/spark-submit --master yarn examples/src/main/r/dataframe.R ``` diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 64ffdcffc9caf..411126a377950 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -1,6 +1,9 @@ # Imports from base R importFrom(methods, setGeneric, setMethod, setOldClass) -useDynLib(SparkR, stringHashCode) + +# Disable native libraries till we figure out how to package it +# See SPARKR-7839 +#useDynLib(SparkR, stringHashCode) # S3 methods exported export("sparkR.init") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a7fa32e291fb1..ed8093c80d360 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -65,9 +65,9 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' printSchema(df) #'} setMethod("printSchema", @@ -88,9 +88,9 @@ setMethod("printSchema", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' dfSchema <- schema(df) #'} setMethod("schema", @@ -110,9 +110,9 @@ setMethod("schema", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' explain(df, TRUE) #'} setMethod("explain", @@ -139,9 +139,9 @@ setMethod("explain", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' isLocal(df) #'} setMethod("isLocal", @@ -162,9 +162,9 @@ setMethod("isLocal", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' showDF(df) #'} setMethod("showDF", @@ -185,9 +185,9 @@ setMethod("showDF", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' df #'} setMethod("show", "DataFrame", @@ -210,9 +210,9 @@ setMethod("show", "DataFrame", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' dtypes(df) #'} setMethod("dtypes", @@ -234,9 +234,9 @@ setMethod("dtypes", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' columns(df) #'} setMethod("columns", @@ -267,11 +267,11 @@ setMethod("names", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "json_df") -#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#' new_df <- sql(sqlContext, "SELECT * FROM json_df") #'} setMethod("registerTempTable", signature(x = "DataFrame", tableName = "character"), @@ -293,9 +293,9 @@ setMethod("registerTempTable", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- read.df(sqlCtx, path, "parquet") -#' df2 <- read.df(sqlCtx, path2, "parquet") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") +#' df2 <- read.df(sqlContext, path2, "parquet") #' registerTempTable(df, "table1") #' insertInto(df2, "table1", overwrite = TRUE) #'} @@ -316,9 +316,9 @@ setMethod("insertInto", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' cache(df) #'} setMethod("cache", @@ -341,9 +341,9 @@ setMethod("cache", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #'} setMethod("persist", @@ -366,9 +366,9 @@ setMethod("persist", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} @@ -391,9 +391,9 @@ setMethod("unpersist", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- repartition(df, 2L) #'} setMethod("repartition", @@ -415,9 +415,9 @@ setMethod("repartition", # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # path <- "path/to/file.json" -# df <- jsonFile(sqlCtx, path) +# df <- jsonFile(sqlContext, path) # newRDD <- toJSON(df) #} setMethod("toJSON", @@ -440,9 +440,9 @@ setMethod("toJSON", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' saveAsParquetFile(df, "/tmp/sparkr-tmp/") #'} setMethod("saveAsParquetFile", @@ -461,9 +461,9 @@ setMethod("saveAsParquetFile", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' distinctDF <- distinct(df) #'} setMethod("distinct", @@ -486,9 +486,9 @@ setMethod("distinct", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} @@ -523,9 +523,9 @@ setMethod("sample_frac", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' count(df) #' } setMethod("count", @@ -545,9 +545,9 @@ setMethod("count", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' collected <- collect(df) #' firstName <- collected[[1]]$name #' } @@ -580,9 +580,9 @@ setMethod("collect", #' @examples #' \dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' limitedDF <- limit(df, 10) #' } setMethod("limit", @@ -599,9 +599,9 @@ setMethod("limit", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' take(df, 2) #' } setMethod("take", @@ -626,9 +626,9 @@ setMethod("take", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' head(df) #' } setMethod("head", @@ -647,9 +647,9 @@ setMethod("head", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' first(df) #' } setMethod("first", @@ -669,9 +669,9 @@ setMethod("first", # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # path <- "path/to/file.json" -# df <- jsonFile(sqlCtx, path) +# df <- jsonFile(sqlContext, path) # rdd <- toRDD(df) # } setMethod("toRDD", @@ -938,9 +938,9 @@ setMethod("select", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } setMethod("selectExpr", @@ -964,9 +964,9 @@ setMethod("selectExpr", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' } setMethod("withColumn", @@ -988,9 +988,9 @@ setMethod("withColumn", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' } @@ -1024,9 +1024,9 @@ setMethod("mutate", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } setMethod("withColumnRenamed", @@ -1055,9 +1055,9 @@ setMethod("withColumnRenamed", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- rename(df, col1 = df$newCol1) #' } setMethod("rename", @@ -1095,9 +1095,9 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' arrange(df, df$col1) #' arrange(df, "col1") #' arrange(df, asc(df$col1), desc(abs(df$col2))) @@ -1137,9 +1137,9 @@ setMethod("orderBy", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } @@ -1177,9 +1177,9 @@ setMethod("where", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") @@ -1219,9 +1219,9 @@ setMethod("join", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' unioned <- unionAll(df, df2) #' } setMethod("unionAll", @@ -1244,9 +1244,9 @@ setMethod("unionAll", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' intersectDF <- intersect(df, df2) #' } setMethod("intersect", @@ -1269,9 +1269,9 @@ setMethod("intersect", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' exceptDF <- except(df, df2) #' } #' @rdname except @@ -1308,9 +1308,9 @@ setMethod("except", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' write.df(df, "myfile", "parquet", "overwrite") #' } setMethod("write.df", @@ -1318,8 +1318,8 @@ setMethod("write.df", mode = 'character'), function(df, path = NULL, source = NULL, mode = "append", ...){ if (is.null(source)) { - sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") @@ -1371,9 +1371,9 @@ setMethod("saveDF", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", @@ -1381,8 +1381,8 @@ setMethod("saveAsTable", mode = 'character'), function(df, tableName, source = NULL, mode="append", ...){ if (is.null(source)) { - sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") @@ -1408,9 +1408,9 @@ setMethod("saveAsTable", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index d3a68fff780ce..0513299515644 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -239,7 +239,7 @@ setMethod("cache", # @aliases persist,RDD-method setMethod("persist", signature(x = "RDD", newLevel = "character"), - function(x, newLevel) { + function(x, newLevel = "MEMORY_ONLY") { callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) x@env$isCached <- TRUE x diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 531442e8459e4..36cc612875879 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -69,7 +69,7 @@ infer_type <- function(x) { #' #' Converts an RDD to a DataFrame by infer the types. #' -#' @param sqlCtx A SQLContext +#' @param sqlContext A SQLContext #' @param data An RDD or list or data.frame #' @param schema a list of column names or named list (StructType), optional #' @return an DataFrame @@ -77,13 +77,13 @@ infer_type <- function(x) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -#' df <- createDataFrame(sqlCtx, rdd) +#' df <- createDataFrame(sqlContext, rdd) #' } # TODO(davies): support sampling and infer type from NA -createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { +createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD schema <- names(data) @@ -102,7 +102,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { }) } if (is.list(data)) { - sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx) + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) rdd <- parallelize(sc, data) } else if (inherits(data, "RDD")) { rdd <- data @@ -146,7 +146,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", - srdd, schema$jobj, sqlCtx) + srdd, schema$jobj, sqlContext) dataFrame(sdf) } @@ -161,7 +161,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) # df <- toDF(rdd) # } @@ -170,14 +170,14 @@ setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), function(x, ...) { - sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + sqlContext <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { get(".sparkRHivesc", envir = .sparkREnv) } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { get(".sparkRSQLsc", envir = .sparkREnv) } else { stop("no SQL context available") } - createDataFrame(sqlCtx, x, ...) + createDataFrame(sqlContext, x, ...) }) #' Create a DataFrame from a JSON file. @@ -185,24 +185,24 @@ setMethod("toDF", signature(x = "RDD"), #' Loads a JSON file (one object per line), returning the result as a DataFrame #' It goes through the entire dataset once to determine the schema. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' } -jsonFile <- function(sqlCtx, path) { +jsonFile <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path path <- normalizePath(path) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") - sdf <- callJMethod(sqlCtx, "jsonFile", path) + sdf <- callJMethod(sqlContext, "jsonFile", path) dataFrame(sdf) } @@ -211,7 +211,7 @@ jsonFile <- function(sqlCtx, path) { # # Loads an RDD storing one JSON object per string as a DataFrame. # -# @param sqlCtx SQLContext to use +# @param sqlContext SQLContext to use # @param rdd An RDD of JSON string # @param schema A StructType object to use as schema # @param samplingRatio The ratio of simpling used to infer the schema @@ -220,16 +220,16 @@ jsonFile <- function(sqlCtx, path) { # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # rdd <- texFile(sc, "path/to/json") -# df <- jsonRDD(sqlCtx, rdd) +# df <- jsonRDD(sqlContext, rdd) # } # TODO: support schema -jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { +jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { rdd <- serializeToString(rdd) if (is.null(schema)) { - sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + sdf <- callJMethod(sqlContext, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) dataFrame(sdf) } else { stop("not implemented") @@ -241,64 +241,63 @@ jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { #' #' Loads a Parquet file, returning the result as a DataFrame. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param ... Path(s) of parquet file(s) to read. #' @return DataFrame #' @export # TODO: Implement saveasParquetFile and write examples for both -parquetFile <- function(sqlCtx, ...) { +parquetFile <- function(sqlContext, ...) { # Allow the user to have a more flexible definiton of the text file path paths <- lapply(list(...), normalizePath) - sdf <- callJMethod(sqlCtx, "parquetFile", paths) + sdf <- callJMethod(sqlContext, "parquetFile", paths) dataFrame(sdf) } #' SQL Query -#' +#' #' Executes a SQL query using Spark, returning the result as a DataFrame. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param sqlQuery A character vector containing the SQL query #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- sql(sqlCtx, "SELECT * FROM table") +#' new_df <- sql(sqlContext, "SELECT * FROM table") #' } -sql <- function(sqlCtx, sqlQuery) { - sdf <- callJMethod(sqlCtx, "sql", sqlQuery) - dataFrame(sdf) +sql <- function(sqlContext, sqlQuery) { + sdf <- callJMethod(sqlContext, "sql", sqlQuery) + dataFrame(sdf) } - #' Create a DataFrame from a SparkSQL Table #' #' Returns the specified Table as a DataFrame. The Table must have already been registered #' in the SQLContext. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The SparkSQL Table to convert to a DataFrame. #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- table(sqlCtx, "table") +#' new_df <- table(sqlContext, "table") #' } -table <- function(sqlCtx, tableName) { - sdf <- callJMethod(sqlCtx, "table", tableName) +table <- function(sqlContext, tableName) { + sdf <- callJMethod(sqlContext, "table", tableName) dataFrame(sdf) } @@ -307,22 +306,22 @@ table <- function(sqlCtx, tableName) { #' #' Returns a DataFrame containing names of tables in the given database. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' tables(sqlCtx, "hive") +#' sqlContext <- sparkRSQL.init(sc) +#' tables(sqlContext, "hive") #' } -tables <- function(sqlCtx, databaseName = NULL) { +tables <- function(sqlContext, databaseName = NULL) { jdf <- if (is.null(databaseName)) { - callJMethod(sqlCtx, "tables") + callJMethod(sqlContext, "tables") } else { - callJMethod(sqlCtx, "tables", databaseName) + callJMethod(sqlContext, "tables", databaseName) } dataFrame(jdf) } @@ -332,22 +331,22 @@ tables <- function(sqlCtx, databaseName = NULL) { #' #' Returns the names of tables in the given database as an array. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a list of table names #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' tableNames(sqlCtx, "hive") +#' sqlContext <- sparkRSQL.init(sc) +#' tableNames(sqlContext, "hive") #' } -tableNames <- function(sqlCtx, databaseName = NULL) { +tableNames <- function(sqlContext, databaseName = NULL) { if (is.null(databaseName)) { - callJMethod(sqlCtx, "tableNames") + callJMethod(sqlContext, "tableNames") } else { - callJMethod(sqlCtx, "tableNames", databaseName) + callJMethod(sqlContext, "tableNames", databaseName) } } @@ -356,58 +355,58 @@ tableNames <- function(sqlCtx, databaseName = NULL) { #' #' Caches the specified table in-memory. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the table being cached #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' cacheTable(sqlCtx, "table") +#' cacheTable(sqlContext, "table") #' } -cacheTable <- function(sqlCtx, tableName) { - callJMethod(sqlCtx, "cacheTable", tableName) +cacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "cacheTable", tableName) } #' Uncache Table #' #' Removes the specified table from the in-memory cache. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the table being uncached #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' uncacheTable(sqlCtx, "table") +#' uncacheTable(sqlContext, "table") #' } -uncacheTable <- function(sqlCtx, tableName) { - callJMethod(sqlCtx, "uncacheTable", tableName) +uncacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "uncacheTable", tableName) } #' Clear Cache #' #' Removes all cached tables from the in-memory cache. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @examples #' \dontrun{ -#' clearCache(sqlCtx) +#' clearCache(sqlContext) #' } -clearCache <- function(sqlCtx) { - callJMethod(sqlCtx, "clearCache") +clearCache <- function(sqlContext) { + callJMethod(sqlContext, "clearCache") } #' Drop Temporary Table @@ -415,22 +414,22 @@ clearCache <- function(sqlCtx) { #' Drops the temporary table with the given table name in the catalog. #' If the table has been cached/persisted before, it's also unpersisted. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the SparkSQL table to be dropped. #' @examples #' \dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- read.df(sqlCtx, path, "parquet") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") #' registerTempTable(df, "table") -#' dropTempTable(sqlCtx, "table") +#' dropTempTable(sqlContext, "table") #' } -dropTempTable <- function(sqlCtx, tableName) { +dropTempTable <- function(sqlContext, tableName) { if (class(tableName) != "character") { stop("tableName must be a string.") } - callJMethod(sqlCtx, "dropTempTable", tableName) + callJMethod(sqlContext, "dropTempTable", tableName) } #' Load an DataFrame @@ -441,7 +440,7 @@ dropTempTable <- function(sqlCtx, tableName) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param path The path of files to load #' @param source the name of external data source #' @return DataFrame @@ -449,24 +448,24 @@ dropTempTable <- function(sqlCtx, tableName) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- read.df(sqlCtx, "path/to/file.json", source = "json") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, "path/to/file.json", source = "json") #' } -read.df <- function(sqlCtx, path = NULL, source = NULL, ...) { +read.df <- function(sqlContext, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path } - sdf <- callJMethod(sqlCtx, "load", source, options) + sdf <- callJMethod(sqlContext, "load", source, options) dataFrame(sdf) } #' @aliases loadDF #' @export -loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { - read.df(sqlCtx, path, source, ...) +loadDF <- function(sqlContext, path = NULL, source = NULL, ...) { + read.df(sqlContext, path, source, ...) } #' Create an external table @@ -478,7 +477,7 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName A name of the table #' @param path The path of files to load #' @param source the name of external data source @@ -487,15 +486,15 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlContext, "myjson", path="path/to/json", source="json") #' } -createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { +createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path } - sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) dataFrame(sdf) } diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 7694652856da5..1e24286dbcae2 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -329,7 +329,7 @@ setMethod("reduceByKey", convertEnvsToList(keys, vals) } locallyReduced <- lapplyPartition(x, reduceVals) - shuffled <- partitionBy(locallyReduced, numPartitions) + shuffled <- partitionBy(locallyReduced, numToInt(numPartitions)) lapplyPartition(shuffled, reduceVals) }) @@ -436,7 +436,7 @@ setMethod("combineByKey", convertEnvsToList(keys, combiners) } locallyCombined <- lapplyPartition(x, combineLocally) - shuffled <- partitionBy(locallyCombined, numPartitions) + shuffled <- partitionBy(locallyCombined, numToInt(numPartitions)) mergeAfterShuffle <- function(part) { combiners <- new.env() keys <- new.env() diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index bc82df01f0fff..68387f0f5365d 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -222,7 +222,7 @@ sparkR.init <- function( #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #'} sparkRSQL.init <- function(jsc) { @@ -230,11 +230,11 @@ sparkRSQL.init <- function(jsc) { return(get(".sparkRSQLsc", envir = .sparkREnv)) } - sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createSQLContext", jsc) - assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) - sqlCtx + assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) + sqlContext } #' Initialize a new HiveContext. @@ -246,7 +246,7 @@ sparkRSQL.init <- function(jsc) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRHive.init(sc) +#' sqlContext <- sparkRHive.init(sc) #'} sparkRHive.init <- function(jsc) { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 0e7b7bd5a5b34..69b2700191c9a 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -122,13 +122,49 @@ hashCode <- function(key) { intBits <- packBits(rawToBits(rawVec), "integer") as.integer(bitwXor(intBits[2], intBits[1])) } else if (class(key) == "character") { - .Call("stringHashCode", key) + # TODO: SPARK-7839 means we might not have the native library available + if (is.loaded("stringHashCode")) { + .Call("stringHashCode", key) + } else { + n <- nchar(key) + if (n == 0) { + 0L + } else { + asciiVals <- sapply(charToRaw(key), function(x) { strtoi(x, 16L) }) + hashC <- 0 + for (k in 1:length(asciiVals)) { + hashC <- mult31AndAdd(hashC, asciiVals[k]) + } + as.integer(hashC) + } + } } else { warning(paste("Could not hash object, returning 0", sep = "")) as.integer(0) } } +# Helper function used to wrap a 'numeric' value to integer bounds. +# Useful for implementing C-like integer arithmetic +wrapInt <- function(value) { + if (value > .Machine$integer.max) { + value <- value - 2 * .Machine$integer.max - 2 + } else if (value < -1 * .Machine$integer.max) { + value <- 2 * .Machine$integer.max + value + 2 + } + value +} + +# Multiply `val` by 31 and add `addVal` to the result. Ensures that +# integer-overflows are handled at every step. +mult31AndAdd <- function(val, addVal) { + vec <- c(bitwShiftL(val, c(4,3,2,1,0)), addVal) + Reduce(function(a, b) { + wrapInt(as.numeric(a) + as.numeric(b)) + }, + vec) +} + # Create a new RDD with serializedMode == "byte". # Return itself if already in "byte" format. serializeToBytes <- function(rdd) { diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 33478d9e29995..ca94f1d4e7fd5 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -26,8 +26,8 @@ sc <- SparkR::sparkR.init(Sys.getenv("MASTER", unset = "")) assign("sc", sc, envir=.GlobalEnv) - sqlCtx <- SparkR::sparkRSQL.init(sc) - assign("sqlCtx", sqlCtx, envir=.GlobalEnv) + sqlContext <- SparkR::sparkRSQL.init(sc) + assign("sqlContext", sqlContext, envir=.GlobalEnv) cat("\n Welcome to SparkR!") - cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") + cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") } diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1768c57fd02e4..1857e636e8577 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -23,7 +23,7 @@ context("SparkSQL functions") sc <- sparkR.init() -sqlCtx <- sparkRSQL.init(sc) +sqlContext <- sparkRSQL.init(sc) mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", @@ -67,25 +67,25 @@ test_that("structType and structField", { test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- createDataFrame(sqlCtx, rdd, list("a", "b")) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - df <- createDataFrame(sqlCtx, rdd) + df <- createDataFrame(sqlContext, rdd) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) - df <- createDataFrame(sqlCtx, rdd, schema) + df <- createDataFrame(sqlContext, rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- createDataFrame(sqlCtx, rdd) + df <- createDataFrame(sqlContext, rdd) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 10) expect_equal(columns(df), c("a", "b")) @@ -121,17 +121,17 @@ test_that("toDF", { test_that("create DataFrame from list or data.frame", { l <- list(list(1, 2), list(3, 4)) - df <- createDataFrame(sqlCtx, l, c("a", "b")) + df <- createDataFrame(sqlContext, l, c("a", "b")) expect_equal(columns(df), c("a", "b")) l <- list(list(a=1, b=2), list(a=3, b=4)) - df <- createDataFrame(sqlCtx, l) + df <- createDataFrame(sqlContext, l) expect_equal(columns(df), c("a", "b")) a <- 1:3 b <- c("a", "b", "c") ldf <- data.frame(a, b) - df <- createDataFrame(sqlCtx, ldf) + df <- createDataFrame(sqlContext, ldf) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) expect_equal(count(df), 3) @@ -142,7 +142,7 @@ test_that("create DataFrame from list or data.frame", { test_that("create DataFrame with different data types", { l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), f = as.POSIXct("2015-03-15 12:13:14.056")) - df <- createDataFrame(sqlCtx, list(l)) + df <- createDataFrame(sqlContext, list(l)) expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), c("d", "string"), c("e", "date"), c("f", "timestamp"))) expect_equal(count(df), 1) @@ -154,7 +154,7 @@ test_that("create DataFrame with different data types", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) -# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) # expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), # c("c", "map"), c("d", "struct"))) # expect_equal(count(df), 1) @@ -163,7 +163,7 @@ test_that("create DataFrame with different data types", { #}) test_that("jsonFile() on a local file returns a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) }) @@ -171,88 +171,88 @@ test_that("jsonFile() on a local file returns a DataFrame", { test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) expect_true(count(rdd) == 3) - df <- jsonRDD(sqlCtx, rdd) + df <- jsonRDD(sqlContext, rdd) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) - df <- jsonRDD(sqlCtx, rdd2) + df <- jsonRDD(sqlContext, rdd2) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 6) }) test_that("test cache, uncache and clearCache", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - cacheTable(sqlCtx, "table1") - uncacheTable(sqlCtx, "table1") - clearCache(sqlCtx) - dropTempTable(sqlCtx, "table1") + cacheTable(sqlContext, "table1") + uncacheTable(sqlContext, "table1") + clearCache(sqlContext) + dropTempTable(sqlContext, "table1") }) test_that("test tableNames and tables", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlCtx)) == 1) - df <- tables(sqlCtx) + expect_true(length(tableNames(sqlContext)) == 1) + df <- tables(sqlContext) expect_true(count(df) == 1) - dropTempTable(sqlCtx, "table1") + dropTempTable(sqlContext, "table1") }) test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") + newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") expect_true(inherits(newdf, "DataFrame")) expect_true(count(newdf) == 1) - dropTempTable(sqlCtx, "table1") + dropTempTable(sqlContext, "table1") }) test_that("insertInto() on a registered table", { - df <- read.df(sqlCtx, jsonPath, "json") + df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(sqlCtx, parquetPath, "parquet") + dfParquet <- read.df(sqlContext, parquetPath, "parquet") lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") writeLines(lines, jsonPath2) - df2 <- read.df(sqlCtx, jsonPath2, "json") + df2 <- read.df(sqlContext, jsonPath2, "json") write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(sqlCtx, parquetPath2, "parquet") + dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlCtx, "select * from table1")) == 5) - expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") - dropTempTable(sqlCtx, "table1") + expect_true(count(sql(sqlContext, "select * from table1")) == 5) + expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Michael") + dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlCtx, "select * from table1")) == 2) - expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") - dropTempTable(sqlCtx, "table1") + expect_true(count(sql(sqlContext, "select * from table1")) == 2) + expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Bob") + dropTempTable(sqlContext, "table1") }) test_that("table() returns a new DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - tabledf <- table(sqlCtx, "table1") + tabledf <- table(sqlContext, "table1") expect_true(inherits(tabledf, "DataFrame")) expect_true(count(tabledf) == 3) - dropTempTable(sqlCtx, "table1") + dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) expect_true(inherits(testRDD, "RDD")) expect_true(count(testRDD) == 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) @@ -274,7 +274,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) @@ -292,7 +292,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { test_that("objectFile() works with row serialization", { objectPath <- tempfile(pattern="spark-test", fileext=".tmp") - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) @@ -303,7 +303,7 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 row @@ -315,7 +315,7 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { }) test_that("collect() returns a data.frame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) expect_true(names(rdf)[1] == "age") @@ -324,20 +324,20 @@ test_that("collect() returns a data.frame", { }) test_that("limit() returns DataFrame with the correct number of rows", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) expect_true(inherits(dfLimited, "DataFrame")) expect_true(count(dfLimited) == 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_true(nrow(collect(df)) == nrow(take(df, 10))) expect_true(ncol(collect(df)) == ncol(take(df, 10))) }) test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 row @@ -354,7 +354,7 @@ test_that("multiple pipeline transformations starting with a DataFrame result in }) test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_false(df@env$isCached) cache(df) expect_true(df@env$isCached) @@ -373,7 +373,7 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) expect_true(length(testSchema$fields()) == 2) expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") @@ -394,7 +394,7 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form }) test_that("head() and first() return the correct data", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) expect_true(nrow(testHead) == 3) expect_true(ncol(testHead) == 2) @@ -415,14 +415,14 @@ test_that("distinct() on DataFrames", { jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlCtx, jsonPathWithDup) + df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) expect_true(inherits(uniques, "DataFrame")) expect_true(count(uniques) == 3) }) test_that("sample on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_true(inherits(sampled, "DataFrame")) @@ -435,7 +435,7 @@ test_that("sample on a DataFrame", { }) test_that("select operators", { - df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") + df <- select(jsonFile(sqlContext, jsonPath), "name", "age") expect_true(inherits(df$name, "Column")) expect_true(inherits(df[[2]], "Column")) expect_true(inherits(df[["age"]], "Column")) @@ -461,7 +461,7 @@ test_that("select operators", { }) test_that("select with column", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") expect_true(columns(df1) == c("name")) expect_true(count(df1) == 3) @@ -472,7 +472,7 @@ test_that("select with column", { }) test_that("selectExpr() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") expect_true(names(selected) == "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) @@ -483,7 +483,7 @@ test_that("selectExpr() on a DataFrame", { }) test_that("column calculation", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) expect_true(names(d) == c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) @@ -492,15 +492,15 @@ test_that("column calculation", { }) test_that("read.df() from json file", { - df <- read.df(sqlCtx, jsonPath, "json") + df <- read.df(sqlContext, jsonPath, "json") expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) }) test_that("write.df() as parquet file", { - df <- read.df(sqlCtx, jsonPath, "json") + df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", mode="overwrite") - df2 <- read.df(sqlCtx, parquetPath, "parquet") + df2 <- read.df(sqlContext, parquetPath, "parquet") expect_true(inherits(df2, "DataFrame")) expect_true(count(df2) == 3) }) @@ -553,7 +553,7 @@ test_that("column binary mathfunctions", { "{\"a\":4, \"b\":8}") jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlCtx, jsonPathWithDup) + df <- jsonFile(sqlContext, jsonPathWithDup) expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) @@ -565,7 +565,7 @@ test_that("column binary mathfunctions", { }) test_that("string operators", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_equal(count(where(df, like(df$name, "A%"))), 1) expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") @@ -573,7 +573,7 @@ test_that("string operators", { }) test_that("group by", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_true(1 == count(df1)) df1 <- agg(df, age2 = max(df$age)) @@ -610,7 +610,7 @@ test_that("group by", { }) test_that("arrange() and orderBy() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) expect_true(collect(sorted)[1,2] == "Michael") @@ -627,7 +627,7 @@ test_that("arrange() and orderBy() on a DataFrame", { }) test_that("filter() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") expect_true(count(filtered) == 1) expect_true(collect(filtered)$name == "Andy") @@ -637,7 +637,7 @@ test_that("filter() on a DataFrame", { }) test_that("join() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", @@ -645,7 +645,7 @@ test_that("join() on a DataFrame", { "{\"name\":\"Bob\", \"test\": \"yes\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - df2 <- jsonFile(sqlCtx, jsonPath2) + df2 <- jsonFile(sqlContext, jsonPath2) joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) @@ -668,7 +668,7 @@ test_that("join() on a DataFrame", { }) test_that("toJSON() returns an RDD of the correct values", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) expect_true(inherits(testRDD, "RDD")) expect_true(SparkR:::getSerializedMode(testRDD) == "string") @@ -676,25 +676,25 @@ test_that("toJSON() returns an RDD of the correct values", { }) test_that("showDF()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) s <- capture.output(showDF(df)) expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n") }) test_that("isLocal()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_false(isLocal(df)) }) test_that("unionAll(), except(), and intersect() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPath2) - df2 <- read.df(sqlCtx, jsonPath2, "json") + df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) expect_true(inherits(unioned, "DataFrame")) @@ -713,7 +713,7 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { }) test_that("withColumn() and withColumnRenamed()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) expect_true(length(columns(newDF)) == 3) expect_true(columns(newDF)[3] == "newAge") @@ -725,7 +725,7 @@ test_that("withColumn() and withColumnRenamed()", { }) test_that("mutate() and rename()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_true(length(columns(newDF)) == 3) expect_true(columns(newDF)[3] == "newAge") @@ -737,25 +737,25 @@ test_that("mutate() and rename()", { }) test_that("write.df() on DataFrame and works with parquetFile", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlCtx, parquetPath) + parquetDF <- parquetFile(sqlContext, parquetPath) expect_true(inherits(parquetDF, "DataFrame")) expect_equal(count(df), count(parquetDF)) }) test_that("parquetFile works with multiple input paths", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) + parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) expect_true(inherits(parquetDF, "DataFrame")) expect_true(count(parquetDF) == count(df)*2) }) test_that("describe() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") diff --git a/R/pkg/src/Makefile b/R/pkg/src-native/Makefile similarity index 100% rename from R/pkg/src/Makefile rename to R/pkg/src-native/Makefile diff --git a/R/pkg/src/Makefile.win b/R/pkg/src-native/Makefile.win similarity index 100% rename from R/pkg/src/Makefile.win rename to R/pkg/src-native/Makefile.win diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src-native/string_hash_code.c similarity index 100% rename from R/pkg/src/string_hash_code.c rename to R/pkg/src-native/string_hash_code.c diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 2e0cb5db170ac..7de0011a48ca8 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -126,9 +126,9 @@ #*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink # Polling period for Slf4JSink -#*.sink.sl4j.period=1 +#*.sink.slf4j.period=1 -#*.sink.sl4j.unit=minutes +#*.sink.slf4j.unit=minutes # Enable jvm source for instance master, worker, driver and executor diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 66bda68088502..9514604752640 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -91,7 +91,7 @@ private[spark] class ExecutorAllocationManager( // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.schedulerBacklogTimeout", "5s") + "spark.dynamicAllocation.schedulerBacklogTimeout", "1s") // Same as above, but used only after `schedulerBacklogTimeoutS` is exceeded private val sustainedSchedulerBacklogTimeoutS = conf.getTimeAsSeconds( @@ -99,7 +99,7 @@ private[spark] class ExecutorAllocationManager( // How long an executor must be idle for before it is removed (seconds) private val executorIdleTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.executorIdleTimeout", "600s") + "spark.dynamicAllocation.executorIdleTimeout", "60s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -268,6 +268,8 @@ private[spark] class ExecutorAllocationManager( numExecutorsTarget = math.max(maxNeeded, minNumExecutors) client.requestTotalExecutors(numExecutorsTarget) numExecutorsToAdd = 1 + logInfo(s"Lowering target number of executors to $numExecutorsTarget because " + + s"not all requests are actually needed (previously $oldNumExecutorsTarget)") numExecutorsTarget - oldNumExecutorsTarget } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) @@ -292,9 +294,8 @@ private[spark] class ExecutorAllocationManager( private def addExecutors(maxNumExecutorsNeeded: Int): Int = { // Do not request more executors if it would put our target over the upper bound if (numExecutorsTarget >= maxNumExecutors) { - val numExecutorsPending = numExecutorsTarget - executorIds.size - logDebug(s"Not adding executors because there are already ${executorIds.size} registered " + - s"and ${numExecutorsPending} pending executor(s) (limit $maxNumExecutors)") + logDebug(s"Not adding executors because our current target total " + + s"is already $numExecutorsTarget (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } @@ -310,10 +311,19 @@ private[spark] class ExecutorAllocationManager( // Ensure that our target fits within configured bounds: numExecutorsTarget = math.max(math.min(numExecutorsTarget, maxNumExecutors), minNumExecutors) + val delta = numExecutorsTarget - oldNumExecutorsTarget + + // If our target has not changed, do not send a message + // to the cluster manager and reset our exponential growth + if (delta == 0) { + numExecutorsToAdd = 1 + return 0 + } + val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) if (addRequestAcknowledged) { - val delta = numExecutorsTarget - oldNumExecutorsTarget - logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" + + val executorsString = "executor" + { if (delta > 1) "s" else "" } + logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + s" (new desired total will be $numExecutorsTarget)") numExecutorsToAdd = if (delta == numExecutorsToAdd) { numExecutorsToAdd * 2 @@ -420,7 +430,7 @@ private[spark] class ExecutorAllocationManager( * This resets all variables used for adding executors. */ private def onSchedulerQueueEmpty(): Unit = synchronized { - logDebug(s"Clearing timer to add executors because there are no more pending tasks") + logDebug("Clearing timer to add executors because there are no more pending tasks") addTime = NOT_SET numExecutorsToAdd = 1 } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ad78bdfde2dfb..ea6c0dea08e47 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1884,7 +1884,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * @param f the closure to clean * @param checkSerializable whether or not to immediately check f for serializability - * @throws SparkException if checkSerializable is set but f is not + * @throws SparkException if checkSerializable is set but f is not * serializable */ private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = { diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 398ca41e16151..fe6320b504e15 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -105,23 +105,18 @@ private[spark] object TestUtils { URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}") } - private class JavaSourceFromString(val name: String, val code: String) + private[spark] class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { override def getCharContent(ignoreEncodingErrors: Boolean): String = code } - /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + /** Creates a compiled class with the source file. Class file will be placed in destDir. */ def createCompiledClass( className: String, destDir: File, - toStringValue: String = "", - baseClass: String = null, - classpathUrls: Seq[URL] = Seq()): File = { + sourceFile: JavaSourceFromString, + classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") - val sourceFile = new JavaSourceFromString(className, - "public class " + className + extendsText + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. @@ -144,4 +139,18 @@ private[spark] object TestUtils { assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath()) out } + + /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + def createCompiledClass( + className: String, + destDir: File, + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") + val sourceFile = new JavaSourceFromString(className, + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") + createCompiledClass(className, destDir, sourceFile, classpathUrls) + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 860e1a24901b6..0550f00a172ab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -43,6 +43,8 @@ class LocalSparkCluster( private val localHostname = Utils.localHostName() private val masterActorSystems = ArrayBuffer[ActorSystem]() private val workerActorSystems = ArrayBuffer[ActorSystem]() + // exposed for testing + var masterWebUIPort = -1 def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") @@ -53,7 +55,9 @@ class LocalSparkCluster( .set("spark.shuffle.service.enabled", "false") /* Start the Master */ - val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) + val (masterSystem, masterPort, webUiPort, _) = + Master.startSystemAndActor(localHostname, 0, 0, _conf) + masterWebUIPort = webUiPort masterActorSystems += masterSystem val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort val masters = Array(masterUrl) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 329fa06ba8ba5..198371b70f14f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -753,7 +753,9 @@ private[spark] object SparkSubmitUtils { * @param artifactId the artifactId of the coordinate * @param version the version of the coordinate */ - private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) { + override def toString: String = s"$groupId:$artifactId:$version" + } /** * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided @@ -776,6 +778,10 @@ private[spark] object SparkSubmitUtils { } } + /** Path of the local Maven cache. */ + private[spark] def m2Path: File = new File(System.getProperty("user.home"), + ".m2" + File.separator + "repository" + File.separator) + /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories @@ -789,8 +795,7 @@ private[spark] object SparkSubmitUtils { val localM2 = new IBiblioResolver localM2.setM2compatible(true) - val m2Path = ".m2" + File.separator + "repository" + File.separator - localM2.setRoot(new File(System.getProperty("user.home"), m2Path).toURI.toString) + localM2.setRoot(m2Path.toURI.toString) localM2.setUsepoms(true) localM2.setName("local-m2-cache") cr.add(localM2) @@ -915,69 +920,72 @@ private[spark] object SparkSubmitUtils { "" } else { val sysOut = System.out - // To prevent ivy from logging to system out - System.setOut(printStream) - val artifacts = extractMavenCoordinates(coordinates) - // Default configuration name for ivy - val ivyConfName = "default" - // set ivy settings for location of cache - val ivySettings: IvySettings = new IvySettings - // Directories for caching downloads through ivy and storing the jars when maven coordinates - // are supplied to spark-submit - val alternateIvyCache = ivyPath.getOrElse("") - val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { - new File(ivySettings.getDefaultIvyUserDir, "jars") + try { + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") + } else { + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") + } + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) } else { - ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) - ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) - new File(alternateIvyCache, "jars") + resolveOptions.setDownload(true) } - printStream.println( - s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") - printStream.println(s"The jars for the packages stored in: $packagesDirectory") - // create a pattern matcher - ivySettings.addMatcher(new GlobPatternMatcher) - // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos, ivySettings) - ivySettings.addResolver(repoResolver) - ivySettings.setDefaultResolver(repoResolver.getName) - - val ivy = Ivy.newInstance(ivySettings) - // Set resolve options to download transitive dependencies as well - val resolveOptions = new ResolveOptions - resolveOptions.setTransitive(true) - val retrieveOptions = new RetrieveOptions - // Turn downloading and logging off for testing - if (isTest) { - resolveOptions.setDownload(false) - resolveOptions.setLog(LogOptions.LOG_QUIET) - retrieveOptions.setLog(LogOptions.LOG_QUIET) - } else { - resolveOptions.setDownload(true) - } - // A Module descriptor must be specified. Entries are dummy strings - val md = getModuleDescriptor - md.setDefaultConf(ivyConfName) + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + md.setDefaultConf(ivyConfName) - // Add exclusion rules for Spark and Scala Library - addExclusionRules(ivySettings, ivyConfName, md) - // add all supplied maven artifacts as dependencies - addDependenciesToIvy(md, artifacts, ivyConfName) + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) - // resolve dependencies - val rr: ResolveReport = ivy.resolve(md, resolveOptions) - if (rr.hasError) { - throw new RuntimeException(rr.getAllProblemMessages.toString) + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } finally { + System.setOut(sysOut) } - // retrieve all resolved dependencies - ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + - "[organization]_[artifact]-[revision].[ext]", - retrieveOptions.setConfs(Array(ivyConfName))) - System.setOut(sysOut) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f107148f3b8c6..c5bc6294a5577 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { + // If this DriverEndpoint is changed to support multiple threads, + // then this may need to be changed so that we don't share the serializer + // instance across threads + private val ser = SparkEnv.get.closureSerializer.newInstance() + override protected def log = CoarseGrainedSchedulerBackend.this.log private val addressToExecutorId = new HashMap[RpcAddress, String] @@ -163,7 +168,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Make fake resource offers on all executors - def makeOffers() { + private def makeOffers() { launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) }.toSeq)) @@ -175,16 +180,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Make fake resource offers on just one executor - def makeOffers(executorId: String) { + private def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) launchTasks(scheduler.resourceOffers( Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) } // Launch tasks returned by a set of resource offers - def launchTasks(tasks: Seq[Seq[TaskDescription]]) { + private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - val ser = SparkEnv.get.closureSerializer.newInstance() val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 64ba27f34d2f1..217957963437d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -177,6 +177,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() + kryo.reset() // We must reset in case this serializer instance was reused (see SPARK-7766) try { kryo.writeClassAndObject(output, t) } catch { @@ -202,6 +203,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ } override def serializeStream(s: OutputStream): SerializationStream = { + kryo.reset() // We must reset in case this serializer instance was reused (see SPARK-7766) new KryoSerializationStream(kryo, s) } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 1df9cd0fa18b4..594df15e9cc85 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -77,7 +77,10 @@ private[spark] abstract class WebUI( val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, (request: HttpServletRequest) => page.render(request), securityManager, basePath) + val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", + (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath) attachHandler(renderHandler) + attachHandler(renderJsonHandler) pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) .append(renderHandler) } diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala new file mode 100644 index 0000000000000..7d39984424842 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.{File, FileInputStream, FileOutputStream} +import java.util.jar.{JarEntry, JarOutputStream} + +import com.google.common.io.{Files, ByteStreams} + +import org.apache.commons.io.FileUtils + +import org.apache.spark.TestUtils.{createCompiledClass, JavaSourceFromString} +import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate + +private[deploy] object IvyTestUtils { + + /** + * Create the path for the jar and pom from the maven coordinate. Extension should be `jar` + * or `pom`. + */ + private def pathFromCoordinate( + artifact: MavenCoordinate, + prefix: File, + ext: String, + useIvyLayout: Boolean): File = { + val groupDirs = artifact.groupId.replace(".", File.separator) + val artifactDirs = artifact.artifactId + val artifactPath = + if (!useIvyLayout) { + Seq(groupDirs, artifactDirs, artifact.version).mkString(File.separator) + } else { + Seq(groupDirs, artifactDirs, artifact.version, ext + "s").mkString(File.separator) + } + new File(prefix, artifactPath) + } + + private def artifactName(artifact: MavenCoordinate, ext: String = ".jar"): String = { + s"${artifact.artifactId}-${artifact.version}$ext" + } + + /** Write the contents to a file to the supplied directory. */ + private def writeFile(dir: File, fileName: String, contents: String): File = { + val outputFile = new File(dir, fileName) + val outputStream = new FileOutputStream(outputFile) + outputStream.write(contents.toCharArray.map(_.toByte)) + outputStream.close() + outputFile + } + + /** Create an example Python file. */ + private def createPythonFile(dir: File): File = { + val contents = + """def myfunc(x): + | return x + 1 + """.stripMargin + writeFile(dir, "mylib.py", contents) + } + + /** Create a simple testable Class. */ + private def createJavaClass(dir: File, className: String, packageName: String): File = { + val contents = + s"""package $packageName; + | + |import java.lang.Integer; + | + |class $className implements java.io.Serializable { + | + | public $className() {} + | + | public Integer myFunc(Integer x) { + | return x + 1; + | } + |} + """.stripMargin + val sourceFile = + new JavaSourceFromString(new File(dir, className + ".java").getAbsolutePath, contents) + createCompiledClass(className, dir, sourceFile, Seq.empty) + } + + /** Helper method to write artifact information in the pom. */ + private def pomArtifactWriter(artifact: MavenCoordinate, tabCount: Int = 1): String = { + var result = "\n" + " " * tabCount + s"${artifact.groupId}" + result += "\n" + " " * tabCount + s"${artifact.artifactId}" + result += "\n" + " " * tabCount + s"${artifact.version}" + result + } + + /** Create a pom file for this artifact. */ + private def createPom( + dir: File, + artifact: MavenCoordinate, + dependencies: Option[Seq[MavenCoordinate]]): File = { + var content = """ + | + | + | 4.0.0 + """.stripMargin.trim + content += pomArtifactWriter(artifact) + content += dependencies.map { deps => + val inside = deps.map { dep => + "\t" + pomArtifactWriter(dep, 3) + "\n\t" + }.mkString("\n") + "\n \n" + inside + "\n " + }.getOrElse("") + content += "\n" + writeFile(dir, artifactName(artifact, ".pom"), content.trim) + } + + /** Create the jar for the given maven coordinate, using the supplied files. */ + private def packJar( + dir: File, + artifact: MavenCoordinate, + files: Seq[(String, File)]): File = { + val jarFile = new File(dir, artifactName(artifact)) + val jarFileStream = new FileOutputStream(jarFile) + val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) + + for (file <- files) { + val jarEntry = new JarEntry(file._1) + jarStream.putNextEntry(jarEntry) + + val in = new FileInputStream(file._2) + ByteStreams.copy(in, jarStream) + in.close() + } + jarStream.close() + jarFileStream.close() + + jarFile + } + + /** + * Creates a jar and pom file, mocking a Maven repository. The root path can be supplied with + * `tempDir`, dependencies can be created into the same repo, and python files can also be packed + * inside the jar. + * + * @param artifact The maven coordinate to generate the jar and pom for. + * @param dependencies List of dependencies this artifact might have to also create jars and poms. + * @param tempDir The root folder of the repository + * @param useIvyLayout whether to mock the Ivy layout for local repository testing + * @param withPython Whether to pack python files inside the jar for extensive testing. + * @return Root path of the repository + */ + private def createLocalRepository( + artifact: MavenCoordinate, + dependencies: Option[Seq[MavenCoordinate]] = None, + tempDir: Option[File] = None, + useIvyLayout: Boolean = false, + withPython: Boolean = false): File = { + // Where the root of the repository exists, and what Ivy will search in + val tempPath = tempDir.getOrElse(Files.createTempDir()) + // Create directory if it doesn't exist + Files.createParentDirs(tempPath) + // Where to create temporary class files and such + val root = new File(tempPath, tempPath.hashCode().toString) + Files.createParentDirs(new File(root, "dummy")) + try { + val jarPath = pathFromCoordinate(artifact, tempPath, "jar", useIvyLayout) + Files.createParentDirs(new File(jarPath, "dummy")) + val className = "MyLib" + + val javaClass = createJavaClass(root, className, artifact.groupId) + // A tuple of files representation in the jar, and the file + val javaFile = (artifact.groupId.replace(".", "/") + "/" + javaClass.getName, javaClass) + val allFiles = + if (withPython) { + val pythonFile = createPythonFile(root) + Seq(javaFile, (pythonFile.getName, pythonFile)) + } else { + Seq(javaFile) + } + val jarFile = packJar(jarPath, artifact, allFiles) + assert(jarFile.exists(), "Problem creating Jar file") + val pomPath = pathFromCoordinate(artifact, tempPath, "pom", useIvyLayout) + Files.createParentDirs(new File(pomPath, "dummy")) + val pomFile = createPom(pomPath, artifact, dependencies) + assert(pomFile.exists(), "Problem creating Pom file") + } finally { + FileUtils.deleteDirectory(root) + } + tempPath + } + + /** + * Creates a suite of jars and poms, with or without dependencies, mocking a maven repository. + * @param artifact The main maven coordinate to generate the jar and pom for. + * @param dependencies List of dependencies this artifact might have to also create jars and poms. + * @param rootDir The root folder of the repository (like `~/.m2/repositories`) + * @param useIvyLayout whether to mock the Ivy layout for local repository testing + * @param withPython Whether to pack python files inside the jar for extensive testing. + * @return Root path of the repository. Will be `rootDir` if supplied. + */ + private[deploy] def createLocalRepositoryForTests( + artifact: MavenCoordinate, + dependencies: Option[String], + rootDir: Option[File], + useIvyLayout: Boolean = false, + withPython: Boolean = false): File = { + val deps = dependencies.map(SparkSubmitUtils.extractMavenCoordinates) + val mainRepo = createLocalRepository(artifact, deps, rootDir, useIvyLayout, withPython) + deps.foreach { seq => seq.foreach { dep => + createLocalRepository(dep, None, Some(mainRepo), useIvyLayout, withPython = false) + }} + mainRepo + } + + /** + * Creates a repository for a test, and cleans it up afterwards. + * + * @param artifact The main maven coordinate to generate the jar and pom for. + * @param dependencies List of dependencies this artifact might have to also create jars and poms. + * @param rootDir The root folder of the repository (like `~/.m2/repositories`) + * @param useIvyLayout whether to mock the Ivy layout for local repository testing + * @param withPython Whether to pack python files inside the jar for extensive testing. + * @return Root path of the repository. Will be `rootDir` if supplied. + */ + private[deploy] def withRepository( + artifact: MavenCoordinate, + dependencies: Option[String], + rootDir: Option[File], + useIvyLayout: Boolean = false, + withPython: Boolean = false)(f: String => Unit): Unit = { + val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, + withPython) + try { + f(repo.toURI.toString) + } finally { + // Clean up + if (repo.toString.contains(".m2") || repo.toString.contains(".ivy2")) { + FileUtils.deleteDirectory(new File(repo, + artifact.groupId.replace(".", File.separator) + File.separator + artifact.artifactId)) + dependencies.map(SparkSubmitUtils.extractMavenCoordinates).foreach { seq => + seq.foreach { dep => + FileUtils.deleteDirectory(new File(repo, + dep.artifactId.replace(".", File.separator))) + } + } + } else { + FileUtils.deleteDirectory(repo) + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 61c95419aedcf..ea9227a7e9af5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ +import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch @@ -334,18 +335,23 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties runSparkSubmit(args) } + // SPARK-7287 ignore("includes jars passed in through --packages") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) - val packagesString = "com.databricks:spark-csv_2.10:0.1,com.databricks:spark-avro_2.10:0.1" - val args = Seq( - "--class", JarCreationTest.getClass.getName.stripSuffix("$"), - "--name", "testApp", - "--master", "local-cluster[2,1,512]", - "--packages", packagesString, - "--conf", "spark.ui.enabled=false", - unusedJar.toString, - "com.databricks.spark.csv.DefaultSource", "com.databricks.spark.avro.DefaultSource") - runSparkSubmit(args) + val main = MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") + IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,512]", + "--packages", Seq(main, dep).mkString(","), + "--repositories", repo, + "--conf", "spark.ui.enabled=false", + unusedJar.toString, + "my.great.lib.MyLib", "my.great.dep.MyLib") + runSparkSubmit(args) + } } test("resolves command line argument paths correctly") { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index da9578478bed9..088ca3cb93b49 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.deploy -import java.io.{PrintStream, OutputStream, File} - -import org.apache.ivy.core.settings.IvySettings +import java.io.{File, PrintStream, OutputStream} import scala.collection.mutable.ArrayBuffer - import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.ivy.core.module.descriptor.MDArtifact +import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.IBiblioResolver +import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate + class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { private val noOpOutputStream = new OutputStream { @@ -89,7 +89,7 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { - val ivyPath = "dummy/ivy" + val ivyPath = "dummy" + File.separator + "ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) @@ -98,17 +98,38 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { assert(index >= 0) jPaths = jPaths.substring(index + ivyPath.length) } - // end to end - val jarPath = SparkSubmitUtils.resolveMavenCoordinates( - "com.databricks:spark-csv_2.10:0.1", None, Option(ivyPath), true) - assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") + IvyTestUtils.withRepository(main, None, None) { repo => + // end to end + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), + Option(ivyPath), true) + assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + } } - ignore("search for artifact at other repositories") { - val path = SparkSubmitUtils.resolveMavenCoordinates("com.agimatec:agimatec-validation:0.9.3", - Option("https://oss.sonatype.org/content/repositories/agimatec/"), None, true) - assert(path.indexOf("agimatec-validation") >= 0, "should find package. If it doesn't, check" + - "if package still exists. If it has been removed, replace the example in this test.") + test("search for artifact at local repositories") { + val main = new MavenCoordinate("my.awesome.lib", "mylib", "0.1") + // Local M2 repository + IvyTestUtils.withRepository(main, None, Some(SparkSubmitUtils.m2Path)) { repo => + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) + assert(jarPath.indexOf("mylib") >= 0, "should find artifact") + } + // Local Ivy Repository + val settings = new IvySettings + val ivyLocal = new File(settings.getDefaultIvyUserDir, "local" + File.separator) + IvyTestUtils.withRepository(main, None, Some(ivyLocal), true) { repo => + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) + assert(jarPath.indexOf("mylib") >= 0, "should find artifact") + } + // Local ivy repository with modified home + val dummyIvyPath = "dummy" + File.separator + "ivy" + val dummyIvyLocal = new File(dummyIvyPath, "local" + File.separator) + IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal), true) { repo => + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, + Some(dummyIvyPath), true) + assert(jarPath.indexOf("mylib") >= 0, "should find artifact") + assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") + } } test("dependency not found throws RuntimeException") { @@ -117,7 +138,7 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } } - ignore("neglects Spark and Spark's dependencies") { + test("neglects Spark and Spark's dependencies") { val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") @@ -127,11 +148,11 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, true) assert(path === "", "should return empty path") - // Should not exclude the following dependency. Will throw an error, because it doesn't exist, - // but the fact that it is checking means that it wasn't excluded. - intercept[RuntimeException] { - SparkSubmitUtils.resolveMavenCoordinates(coordinates + - ",org.apache.spark:spark-streaming-kafka-assembly_2.10:1.2.0", None, None, true) + val main = MavenCoordinate("org.apache.spark", "spark-streaming-kafka-assembly_2.10", "1.2.0") + IvyTestUtils.withRepository(main, None, None) { repo => + val files = SparkSubmitUtils.resolveMavenCoordinates(coordinates + "," + main.toString, + Some(repo), None, true) + assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 0faa8f650e5e1..f97e5ff6db31d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -21,16 +21,20 @@ import java.util.Date import scala.concurrent.Await import scala.concurrent.duration._ +import scala.io.Source import scala.language.postfixOps import akka.actor.Address +import org.json4s._ +import org.json4s.jackson.JsonMethods._ import org.scalatest.{FunSuite, Matchers} +import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.deploy._ import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy._ -class MasterSuite extends FunSuite with Matchers { +class MasterSuite extends FunSuite with Matchers with Eventually { test("toAkkaUrl") { val conf = new SparkConf(loadDefaults = false) @@ -157,4 +161,27 @@ class MasterSuite extends FunSuite with Matchers { CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts } + test("Master & worker web ui available") { + implicit val formats = org.json4s.DefaultFormats + val conf = new SparkConf() + val localCluster = new LocalSparkCluster(2, 2, 512, conf) + localCluster.start() + try { + eventually(timeout(5 seconds), interval(100 milliseconds)) { + val json = Source.fromURL(s"http://localhost:${localCluster.masterWebUIPort}/json") + .getLines().mkString("\n") + val JArray(workers) = (parse(json) \ "workers") + workers.size should be (2) + workers.foreach { workerSummaryJson => + val JString(workerWebUi) = workerSummaryJson \ "webuiaddress" + val workerResponse = parse(Source.fromURL(s"${workerWebUi}/json") + .getLines().mkString("\n")) + (workerResponse \ "cores").extract[Int] should be (2) + } + } + } finally { + localCluster.stop() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index c7369de24b81f..5faf108b394a1 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.serializer +import java.io.ByteArrayOutputStream + import scala.collection.mutable import scala.reflect.ClassTag @@ -60,6 +62,10 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { val thrown3 = intercept[IllegalArgumentException](new KryoSerializer(conf4).newInstance()) assert(thrown3.getMessage.contains(kryoBufferProperty)) assert(!thrown3.getMessage.contains(kryoBufferMaxProperty)) + val conf5 = conf.clone() + conf5.set(kryoBufferProperty, "8m") + conf5.set(kryoBufferMaxProperty, "9m") + new KryoSerializer(conf5).newInstance() } test("basic types") { @@ -319,6 +325,37 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { val ser2 = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance] assert(!ser2.getAutoReset) } + + private def testSerializerInstanceReuse(autoReset: Boolean, referenceTracking: Boolean): Unit = { + val conf = new SparkConf(loadDefaults = false) + .set("spark.kryo.referenceTracking", referenceTracking.toString) + if (!autoReset) { + conf.set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName) + } + val ser = new KryoSerializer(conf) + val serInstance = ser.newInstance().asInstanceOf[KryoSerializerInstance] + assert (serInstance.getAutoReset() === autoReset) + val obj = ("Hello", "World") + def serializeObjects(): Array[Byte] = { + val baos = new ByteArrayOutputStream() + val serStream = serInstance.serializeStream(baos) + serStream.writeObject(obj) + serStream.writeObject(obj) + serStream.close() + baos.toByteArray + } + val output1: Array[Byte] = serializeObjects() + val output2: Array[Byte] = serializeObjects() + assert (output1 === output2) + } + + // Regression test for SPARK-7766, an issue where disabling auto-reset and enabling + // reference-tracking would lead to corrupted output when serializer instances are re-used + for (referenceTracking <- Set(true, false); autoReset <- Set(true, false)) { + test(s"instance reuse with autoReset = $autoReset, referenceTracking = $referenceTracking") { + testSerializerInstanceReuse(autoReset = autoReset, referenceTracking = referenceTracking) + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 0c5221d10d79d..0d9126f23ccc5 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -75,10 +75,12 @@ class JsonProtocolSuite extends FunSuite { val blockManagerRemoved = SparkListenerBlockManagerRemoved(2L, BlockManagerId("Scarce", "to be counted...", 100)) val unpersistRdd = SparkListenerUnpersistRDD(12345) + val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap val applicationStart = SparkListenerApplicationStart("The winner of all", Some("appId"), 42L, "Garfield", Some("appAttempt")) + val applicationStartWithLogs = SparkListenerApplicationStart("The winner of all", Some("appId"), + 42L, "Garfield", Some("appAttempt"), Some(logUrlMap)) val applicationEnd = SparkListenerApplicationEnd(42L) - val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") @@ -97,6 +99,7 @@ class JsonProtocolSuite extends FunSuite { testEvent(blockManagerRemoved, blockManagerRemovedJsonString) testEvent(unpersistRdd, unpersistRDDJsonString) testEvent(applicationStart, applicationStartJsonString) + testEvent(applicationStartWithLogs, applicationStartJsonWithLogUrlsString) testEvent(applicationEnd, applicationEndJsonString) testEvent(executorAdded, executorAddedJsonString) testEvent(executorRemoved, executorRemovedJsonString) @@ -277,10 +280,12 @@ class JsonProtocolSuite extends FunSuite { test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. // SparkListenerApplicationStart pre-Spark 1.4 does not have "appAttemptId". - val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user", None) + // SparkListenerApplicationStart pre-Spark 1.5 does not have "driverLogs + val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user", None, None) val oldEvent = JsonProtocol.applicationStartToJson(applicationStart) .removeField({ _._1 == "App ID" }) .removeField({ _._1 == "App Attempt ID" }) + .removeField({ _._1 == "Driver Logs"}) assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent)) } @@ -1544,6 +1549,22 @@ class JsonProtocolSuite extends FunSuite { |} """ + private val applicationStartJsonWithLogUrlsString = + """ + |{ + | "Event": "SparkListenerApplicationStart", + | "App Name": "The winner of all", + | "App ID": "appId", + | "Timestamp": 42, + | "User": "Garfield", + | "App Attempt ID": "appAttempt", + | "Driver Logs" : { + | "stderr" : "mystderr", + | "stdout" : "mystdout" + | } + |} + """ + private val applicationEndJsonString = """ |{ diff --git a/data/mllib/sample_multiclass_classification_data.txt b/data/mllib/sample_multiclass_classification_data.txt new file mode 100644 index 0000000000000..a0d7f90113919 --- /dev/null +++ b/data/mllib/sample_multiclass_classification_data.txt @@ -0,0 +1,150 @@ +1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333 +1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667 +1 1:-0.722222 2:-0.166667 3:-0.864407 4:-0.833333 +1 1:-0.722222 2:0.166667 3:-0.694915 4:-0.916667 +0 1:0.166667 2:-0.416667 3:0.457627 4:0.5 +1 1:-0.833333 3:-0.864407 4:-0.916667 +2 1:-1.32455e-07 2:-0.166667 3:0.220339 4:0.0833333 +2 1:-1.32455e-07 2:-0.333333 3:0.0169491 4:-4.03573e-08 +1 1:-0.5 2:0.75 3:-0.830508 4:-1 +0 1:0.611111 3:0.694915 4:0.416667 +0 1:0.222222 2:-0.166667 3:0.423729 4:0.583333 +1 1:-0.722222 2:-0.166667 3:-0.864407 4:-1 +1 1:-0.5 2:0.166667 3:-0.864407 4:-0.916667 +2 1:-0.222222 2:-0.333333 3:0.0508474 4:-4.03573e-08 +2 1:-0.0555556 2:-0.833333 3:0.0169491 4:-0.25 +2 1:-0.166667 2:-0.416667 3:-0.0169491 4:-0.0833333 +1 1:-0.944444 3:-0.898305 4:-0.916667 +2 1:-0.277778 2:-0.583333 3:-0.0169491 4:-0.166667 +0 1:0.111111 2:-0.333333 3:0.38983 4:0.166667 +2 1:-0.222222 2:-0.166667 3:0.0847457 4:-0.0833333 +0 1:0.166667 2:-0.333333 3:0.559322 4:0.666667 +1 1:-0.611111 2:0.0833333 3:-0.864407 4:-0.916667 +2 1:-0.333333 2:-0.583333 3:0.0169491 4:-4.03573e-08 +0 1:0.555555 2:-0.166667 3:0.661017 4:0.666667 +2 1:0.166667 3:0.186441 4:0.166667 +2 1:0.111111 2:-0.75 3:0.152542 4:-4.03573e-08 +2 1:0.166667 2:-0.25 3:0.118644 4:-4.03573e-08 +0 1:-0.0555556 2:-0.833333 3:0.355932 4:0.166667 +0 1:-0.277778 2:-0.333333 3:0.322034 4:0.583333 +2 1:-0.222222 2:-0.5 3:-0.152542 4:-0.25 +2 1:-0.111111 3:0.288136 4:0.416667 +2 1:-0.0555556 2:-0.25 3:0.186441 4:0.166667 +2 1:0.333333 2:-0.166667 3:0.355932 4:0.333333 +1 1:-0.611111 2:0.25 3:-0.898305 4:-0.833333 +0 1:0.166667 2:-0.333333 3:0.559322 4:0.75 +0 1:0.111111 2:-0.25 3:0.559322 4:0.416667 +0 1:0.833333 2:-0.166667 3:0.898305 4:0.666667 +2 1:-0.277778 2:-0.166667 3:0.186441 4:0.166667 +0 1:-0.666667 2:-0.583333 3:0.186441 4:0.333333 +1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 +1 1:-0.166667 2:0.666667 3:-0.932203 4:-0.916667 +0 1:0.0555554 2:-0.333333 3:0.288136 4:0.416667 +1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 +1 1:-0.833333 2:0.166667 3:-0.864407 4:-0.833333 +0 1:0.0555554 2:0.166667 3:0.491525 4:0.833333 +0 1:0.722222 2:-0.333333 3:0.728813 4:0.5 +2 1:-0.166667 2:-0.416667 3:0.0508474 4:-0.25 +2 1:0.5 3:0.254237 4:0.0833333 +0 1:0.111111 2:-0.583333 3:0.355932 4:0.5 +1 1:-0.944444 2:-0.166667 3:-0.898305 4:-0.916667 +2 1:0.277778 2:-0.25 3:0.220339 4:-4.03573e-08 +0 1:0.666667 2:-0.25 3:0.79661 4:0.416667 +0 1:0.111111 2:0.0833333 3:0.694915 4:1 +0 1:0.444444 3:0.59322 4:0.833333 +2 1:-0.0555556 2:0.166667 3:0.186441 4:0.25 +1 1:-0.833333 2:0.333333 3:-1 4:-0.916667 +1 1:-0.555556 2:0.416667 3:-0.830508 4:-0.75 +2 1:-0.333333 2:-0.5 3:0.152542 4:-0.0833333 +1 1:-1 2:-0.166667 3:-0.966102 4:-1 +1 1:-0.333333 2:0.25 3:-0.898305 4:-0.916667 +2 1:0.388889 2:-0.333333 3:0.288136 4:0.0833333 +2 1:0.277778 2:-0.166667 3:0.152542 4:0.0833333 +0 1:0.333333 2:0.0833333 3:0.59322 4:0.666667 +1 1:-0.777778 3:-0.79661 4:-0.916667 +1 1:-0.444444 2:0.416667 3:-0.830508 4:-0.916667 +0 1:0.222222 2:-0.166667 3:0.627119 4:0.75 +1 1:-0.555556 2:0.5 3:-0.79661 4:-0.916667 +1 1:-0.555556 2:0.5 3:-0.694915 4:-0.75 +2 1:-1.32455e-07 2:-0.25 3:0.254237 4:0.0833333 +1 1:-0.5 2:0.25 3:-0.830508 4:-0.916667 +0 1:0.166667 3:0.457627 4:0.833333 +2 1:0.444444 2:-0.0833334 3:0.322034 4:0.166667 +0 1:0.111111 2:0.166667 3:0.559322 4:0.916667 +1 1:-0.611111 2:0.25 3:-0.79661 4:-0.583333 +0 1:0.388889 3:0.661017 4:0.833333 +1 1:-0.722222 2:0.166667 3:-0.79661 4:-0.916667 +1 1:-0.722222 2:-0.0833334 3:-0.79661 4:-0.916667 +1 1:-0.555556 2:0.166667 3:-0.830508 4:-0.916667 +2 1:-0.666667 2:-0.666667 3:-0.220339 4:-0.25 +2 1:-0.611111 2:-0.75 3:-0.220339 4:-0.25 +2 1:0.0555554 2:-0.833333 3:0.186441 4:0.166667 +0 1:-0.166667 2:-0.416667 3:0.38983 4:0.5 +0 1:0.611111 2:0.333333 3:0.728813 4:1 +2 1:0.0555554 2:-0.25 3:0.118644 4:-4.03573e-08 +1 1:-0.666667 2:-0.166667 3:-0.864407 4:-0.916667 +1 1:-0.833333 2:-0.0833334 3:-0.830508 4:-0.916667 +0 1:0.611111 2:-0.166667 3:0.627119 4:0.25 +0 1:0.888889 2:0.5 3:0.932203 4:0.75 +2 1:0.222222 2:-0.333333 3:0.220339 4:0.166667 +1 1:-0.555556 2:0.25 3:-0.864407 4:-0.833333 +0 1:-1.32455e-07 2:-0.166667 3:0.322034 4:0.416667 +0 1:-1.32455e-07 2:-0.5 3:0.559322 4:0.0833333 +1 1:-0.611111 3:-0.932203 4:-0.916667 +1 1:-0.333333 2:0.833333 3:-0.864407 4:-0.916667 +0 1:-0.166667 2:-0.333333 3:0.38983 4:0.916667 +2 1:-0.333333 2:-0.666667 3:-0.0847458 4:-0.25 +2 1:-0.0555556 2:-0.416667 3:0.38983 4:0.25 +1 1:-0.388889 2:0.416667 3:-0.830508 4:-0.916667 +0 1:0.444444 2:-0.0833334 3:0.38983 4:0.833333 +1 1:-0.611111 2:0.333333 3:-0.864407 4:-0.916667 +0 1:0.111111 2:-0.416667 3:0.322034 4:0.416667 +0 1:0.166667 2:-0.0833334 3:0.525424 4:0.416667 +2 1:0.333333 2:-0.0833334 3:0.152542 4:0.0833333 +0 1:-0.0555556 2:-0.166667 3:0.288136 4:0.416667 +0 1:-0.166667 2:-0.416667 3:0.38983 4:0.5 +1 1:-0.611111 2:0.166667 3:-0.830508 4:-0.916667 +0 1:0.888889 2:-0.166667 3:0.728813 4:0.833333 +2 1:-0.277778 2:-0.25 3:-0.118644 4:-4.03573e-08 +2 1:-0.222222 2:-0.333333 3:0.186441 4:-4.03573e-08 +0 1:0.333333 2:-0.583333 3:0.627119 4:0.416667 +0 1:0.444444 2:-0.0833334 3:0.491525 4:0.666667 +2 1:-0.222222 2:-0.25 3:0.0847457 4:-4.03573e-08 +1 1:-0.611111 2:0.166667 3:-0.79661 4:-0.75 +2 1:-0.277778 2:-0.166667 3:0.0508474 4:-4.03573e-08 +0 1:1 2:0.5 3:0.830508 4:0.583333 +2 1:-0.333333 2:-0.666667 3:-0.0508475 4:-0.166667 +2 1:-0.277778 2:-0.416667 3:0.0847457 4:-4.03573e-08 +0 1:0.888889 2:-0.333333 3:0.932203 4:0.583333 +2 1:-0.111111 2:-0.166667 3:0.0847457 4:0.166667 +2 1:0.111111 2:-0.583333 3:0.322034 4:0.166667 +0 1:0.333333 2:0.0833333 3:0.59322 4:1 +0 1:0.222222 2:-0.166667 3:0.525424 4:0.416667 +1 1:-0.555556 2:0.5 3:-0.830508 4:-0.833333 +0 1:-0.111111 2:-0.166667 3:0.38983 4:0.416667 +0 1:0.888889 2:-0.5 3:1 4:0.833333 +1 1:-0.388889 2:0.583333 3:-0.898305 4:-0.75 +2 1:0.111111 2:0.0833333 3:0.254237 4:0.25 +0 1:0.333333 2:-0.166667 3:0.423729 4:0.833333 +1 1:-0.388889 2:0.166667 3:-0.762712 4:-0.916667 +0 1:0.333333 2:-0.0833334 3:0.559322 4:0.916667 +2 1:-0.333333 2:-0.75 3:0.0169491 4:-4.03573e-08 +1 1:-0.222222 2:1 3:-0.830508 4:-0.75 +1 1:-0.388889 2:0.583333 3:-0.762712 4:-0.75 +2 1:-0.611111 2:-1 3:-0.152542 4:-0.25 +2 1:-1.32455e-07 2:-0.333333 3:0.254237 4:-0.0833333 +2 1:-0.5 2:-0.416667 3:-0.0169491 4:0.0833333 +1 1:-0.888889 2:-0.75 3:-0.898305 4:-0.833333 +1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 +2 1:-0.555556 2:-0.583333 3:-0.322034 4:-0.166667 +2 1:-0.166667 2:-0.5 3:0.0169491 4:-0.0833333 +1 1:-0.555556 2:0.0833333 3:-0.762712 4:-0.666667 +1 1:-0.777778 3:-0.898305 4:-0.916667 +0 1:0.388889 2:-0.166667 3:0.525424 4:0.666667 +0 1:0.222222 3:0.38983 4:0.583333 +2 1:0.333333 2:-0.0833334 3:0.254237 4:0.166667 +2 1:-0.388889 2:-0.166667 3:0.186441 4:0.166667 +0 1:-0.222222 2:-0.583333 3:0.355932 4:0.583333 +1 1:-0.611111 2:-0.166667 3:-0.79661 4:-0.916667 +1 1:-0.944444 2:-0.25 3:-0.864407 4:-0.916667 +1 1:-0.388889 2:0.166667 3:-0.830508 4:-0.75 diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index af4f00054997c..54274a83f6d66 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -228,14 +228,14 @@ if [[ ! "$@" =~ --skip-package ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop1" "-Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & + make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & wait rm -rf spark-$RELEASE_VERSION-bin-*/ diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 26221b270394e..51ab25a6a5bd8 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -27,7 +27,7 @@ from jira.exceptions import JIRAError except ImportError: print "This tool requires the jira-python library" - print "Install using 'sudo pip install jira-python'" + print "Install using 'sudo pip install jira'" sys.exit(-1) try: diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py index ff1e39664ee04..287f0ca24a7df 100755 --- a/dev/github_jira_sync.py +++ b/dev/github_jira_sync.py @@ -28,7 +28,7 @@ import jira.client except ImportError: print "This tool requires the jira-python library" - print "Install using 'sudo pip install jira-python'" + print "Install using 'sudo pip install jira'" sys.exit(-1) # User facing configs diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 1c126f50bf095..787c5cc8e892d 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -426,7 +426,7 @@ def main(): print "JIRA_USERNAME and JIRA_PASSWORD not set" print "Exiting without trying to close the associated JIRA." else: - print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." + print "Could not find jira-python library. Run 'sudo pip install jira' to install." print "Exiting without trying to close the associated JIRA." if __name__ == "__main__": diff --git a/dev/run-tests b/dev/run-tests index b444e74706b65..7dd8d31fd44e3 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -40,7 +40,7 @@ function handle_error () { { if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Phadoop-1 -Dhadoop.version=1.0.4" + export SBT_MAVEN_PROFILES_ARGS="-Phadoop-1 -Dhadoop.version=1.2.1" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then export SBT_MAVEN_PROFILES_ARGS="-Phadoop-1 -Dhadoop.version=2.0.0-mr1-cdh4.1.1" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 0ea3f8eab461b..6073b3626c45b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -18,50 +18,52 @@ require 'fileutils' include FileUtils -if not (ENV['SKIP_API'] == '1' or ENV['SKIP_SCALADOC'] == '1') - # Build Scaladoc for Java/Scala +if not (ENV['SKIP_API'] == '1') + if not (ENV['SKIP_SCALADOC'] == '1') + # Build Scaladoc for Java/Scala - puts "Moving to project root and building API docs." - curr_dir = pwd - cd("..") + puts "Moving to project root and building API docs." + curr_dir = pwd + cd("..") - puts "Running 'build/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl compile unidoc` + puts "Running 'build/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." + puts `build/sbt -Pkinesis-asl compile unidoc` - puts "Moving back into docs dir." - cd("docs") + puts "Moving back into docs dir." + cd("docs") - # Copy over the unified ScalaDoc for all projects to api/scala. - # This directory will be copied over to _site when `jekyll` command is run. - source = "../target/scala-2.10/unidoc" - dest = "api/scala" + # Copy over the unified ScalaDoc for all projects to api/scala. + # This directory will be copied over to _site when `jekyll` command is run. + source = "../target/scala-2.10/unidoc" + dest = "api/scala" - puts "Making directory " + dest - mkdir_p dest + puts "Making directory " + dest + mkdir_p dest - # From the rubydoc: cp_r('src', 'dest') makes src/dest, but this doesn't. - puts "cp -r " + source + "/. " + dest - cp_r(source + "/.", dest) + # From the rubydoc: cp_r('src', 'dest') makes src/dest, but this doesn't. + puts "cp -r " + source + "/. " + dest + cp_r(source + "/.", dest) - # Append custom JavaScript - js = File.readlines("./js/api-docs.js") - js_file = dest + "/lib/template.js" - File.open(js_file, 'a') { |f| f.write("\n" + js.join()) } + # Append custom JavaScript + js = File.readlines("./js/api-docs.js") + js_file = dest + "/lib/template.js" + File.open(js_file, 'a') { |f| f.write("\n" + js.join()) } - # Append custom CSS - css = File.readlines("./css/api-docs.css") - css_file = dest + "/lib/template.css" - File.open(css_file, 'a') { |f| f.write("\n" + css.join()) } + # Append custom CSS + css = File.readlines("./css/api-docs.css") + css_file = dest + "/lib/template.css" + File.open(css_file, 'a') { |f| f.write("\n" + css.join()) } - # Copy over the unified JavaDoc for all projects to api/java. - source = "../target/javaunidoc" - dest = "api/java" + # Copy over the unified JavaDoc for all projects to api/java. + source = "../target/javaunidoc" + dest = "api/java" - puts "Making directory " + dest - mkdir_p dest + puts "Making directory " + dest + mkdir_p dest - puts "cp -r " + source + "/. " + dest - cp_r(source + "/.", dest) + puts "cp -r " + source + "/. " + dest + cp_r(source + "/.", dest) + end # Build Sphinx docs for Python diff --git a/docs/api.md b/docs/api.md index 03460383335e8..45df77ac05f78 100644 --- a/docs/api.md +++ b/docs/api.md @@ -7,4 +7,5 @@ Here you can API docs for Spark and its submodules. - [Spark Scala API (Scaladoc)](api/scala/index.html) - [Spark Java API (Javadoc)](api/java/index.html) -- [Spark Python API (Epydoc)](api/python/index.html) +- [Spark Python API (Sphinx)](api/python/index.html) +- [Spark R API (Roxygen2)](api/R/index.html) diff --git a/docs/configuration.md b/docs/configuration.md index 0de824546c751..30508a617fdd8 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1194,7 +1194,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.executorIdleTimeout - 600s + 60s If dynamic allocation is enabled and an executor has been idle for more than this duration, the executor will be removed. For more detail, see this @@ -1224,7 +1224,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.schedulerBacklogTimeout - 5s + 1s If dynamic allocation is enabled and there have been pending tasks backlogged for more than this duration, new executors will be requested. For more detail, see this diff --git a/docs/index.md b/docs/index.md index b5b016e34795e..5ef6d983c45a5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,7 +6,7 @@ description: Apache Spark SPARK_VERSION_SHORT documentation homepage --- Apache Spark is a fast and general-purpose cluster computing system. -It provides high-level APIs in Java, Scala and Python, +It provides high-level APIs in Java, Scala, Python and R, and an optimized engine that supports general execution graphs. It also supports a rich set of higher-level tools including [Spark SQL](sql-programming-guide.html) for SQL and structured data processing, [MLlib](mllib-guide.html) for machine learning, [GraphX](graphx-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html). @@ -20,13 +20,13 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy locally on one machine --- all you need is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. -Spark runs on Java 6+ and Python 2.6+. For the Scala API, Spark {{site.SPARK_VERSION}} uses +Spark runs on Java 6+, Python 2.6+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). # Running the Examples and Shell -Spark comes with several sample programs. Scala, Java and Python examples are in the +Spark comes with several sample programs. Scala, Java, Python and R examples are in the `examples/src/main` directory. To run one of the Java or Scala sample programs, use `bin/run-example [params]` in the top-level Spark directory. (Behind the scenes, this invokes the more general @@ -54,6 +54,15 @@ Example applications are also provided in Python. For example, ./bin/spark-submit examples/src/main/python/pi.py 10 +Spark also provides an experimental R API since 1.4 (only DataFrames APIs included). +To run Spark interactively in a R interpreter, use `bin/sparkR`: + + ./bin/sparkR --master local[2] + +Example applications are also provided in R. For example, + + ./bin/spark-submit examples/src/main/r/dataframe.R + # Launching on a Cluster The Spark [cluster mode overview](cluster-overview.html) explains the key concepts in running on a cluster. @@ -71,7 +80,7 @@ options for deployment: * [Quick Start](quick-start.html): a quick introduction to the Spark API; start here! * [Spark Programming Guide](programming-guide.html): detailed overview of Spark - in all supported languages (Scala, Java, Python) + in all supported languages (Scala, Java, Python, R) * Modules built on Spark: * [Spark Streaming](streaming-programming-guide.html): processing real-time data streams * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries @@ -83,7 +92,8 @@ options for deployment: * [Spark Scala API (Scaladoc)](api/scala/index.html#org.apache.spark.package) * [Spark Java API (Javadoc)](api/java/index.html) -* [Spark Python API (Epydoc)](api/python/index.html) +* [Spark Python API (Sphinx)](api/python/index.html) +* [Spark R API (Roxygen2)](api/R/index.html) **Deployment Guides:** @@ -124,4 +134,5 @@ options for deployment: available online for free. * [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)) + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md new file mode 100644 index 0000000000000..9ff50e95fc479 --- /dev/null +++ b/docs/ml-ensembles.md @@ -0,0 +1,129 @@ +--- +layout: global +title: Ensembles +displayTitle: ML - Ensembles +--- + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) +is a learning algorithm which creates a model composed of a set of other base models. +The Pipelines API supports the following ensemble algorithms: [`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) + +## OneVsRest + +[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. + +`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. + +Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. + +### Example + +The example below demonstrates how to load the +[Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. + +
+
+{% highlight scala %} +import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.{Row, SQLContext} + +val sqlContext = new SQLContext(sc) + +// parse data into dataframe +val data = MLUtils.loadLibSVMFile(sc, + "data/mllib/sample_multiclass_classification_data.txt") +val Array(train, test) = data.toDF().randomSplit(Array(0.7, 0.3)) + +// instantiate multiclass learner and train +val ovr = new OneVsRest().setClassifier(new LogisticRegression) + +val ovrModel = ovr.fit(train) + +// score model on test data +val predictions = ovrModel.transform(test).select("prediction", "label") +val predictionsAndLabels = predictions.map {case Row(p: Double, l: Double) => (p, l)} + +// compute confusion matrix +val metrics = new MulticlassMetrics(predictionsAndLabels) +println(metrics.confusionMatrix) + +// the Iris DataSet has three classes +val numClasses = 3 + +println("label\tfpr\n") +(0 until numClasses).foreach { index => + val label = index.toDouble + println(label + "\t" + metrics.falsePositiveRate(label)) +} +{% endhighlight %} +
+
+{% highlight java %} + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.OneVsRest; +import org.apache.spark.ml.classification.OneVsRestModel; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); +JavaSparkContext jsc = new JavaSparkContext(conf); +SQLContext jsql = new SQLContext(jsc); + +RDD data = MLUtils.loadLibSVMFile(jsc.sc(), + "data/mllib/sample_multiclass_classification_data.txt"); + +DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame[] splits = dataFrame.randomSplit(new double[]{0.7, 0.3}, 12345); +DataFrame train = splits[0]; +DataFrame test = splits[1]; + +// instantiate the One Vs Rest Classifier +OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression()); + +// train the multiclass model +OneVsRestModel ovrModel = ovr.fit(train.cache()); + +// score the model on test data +DataFrame predictions = ovrModel + .transform(test) + .select("prediction", "label"); + +// obtain metrics +MulticlassMetrics metrics = new MulticlassMetrics(predictions); +Matrix confusionMatrix = metrics.confusionMatrix(); + +// output the Confusion Matrix +System.out.println("Confusion Matrix"); +System.out.println(confusionMatrix); + +// compute the false positive rate per label +System.out.println(); +System.out.println("label\tfpr\n"); + +// the Iris DataSet has three classes +int numClasses = 3; +for (int index = 0; index < numClasses; index++) { + double label = (double) index; + System.out.print(label); + System.out.print("\t"); + System.out.print(metrics.falsePositiveRate(label)); + System.out.println(); +} +{% endhighlight %} +
+
diff --git a/docs/ml-features.md b/docs/ml-features.md index 06f1ac196b39d..efe9b3b8edb6e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -18,30 +18,38 @@ This section covers algorithms for working with features, roughly divided into t # Feature Extractors -## Hashing Term-Frequency (HashingTF) +## TF-IDF (HashingTF and IDF) -`HashingTF` is a `Transformer` which takes sets of terms (e.g., `String` terms can be sets of words) and converts those sets into fixed-length feature vectors. -The algorithm combines [Term Frequency (TF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) counts with the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing) for dimensionality reduction. Please refer to the [MLlib user guide on TF-IDF](mllib-feature-extraction.html#tf-idf) for more details on Term-Frequency. +[Term Frequency-Inverse Document Frequency (TF-IDF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) is a common text pre-processing step. In Spark ML, TF-IDF is separate into two parts: TF (+hashing) and IDF. -HashingTF is implemented in -[HashingTF](api/scala/index.html#org.apache.spark.ml.feature.HashingTF). -In the following code segment, we start with a set of sentences. We split each sentence into words using `Tokenizer`. For each sentence (bag of words), we hash it into a feature vector. This feature vector could then be passed to a learning algorithm. +**TF**: `HashingTF` is a `Transformer` which takes sets of terms and converts those sets into fixed-length feature vectors. In text processing, a "set of terms" might be a bag of words. +The algorithm combines Term Frequency (TF) counts with the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing) for dimensionality reduction. + +**IDF**: `IDF` is an `Estimator` which fits on a dataset and produces an `IDFModel`. The `IDFModel` takes feature vectors (generally created from `HashingTF`) and scales each column. Intuitively, it down-weights columns which appear frequently in a corpus. + +Please refer to the [MLlib user guide on TF-IDF](mllib-feature-extraction.html#tf-idf) for more details on Term Frequency and Inverse Document Frequency. +For API details, refer to the [HashingTF API docs](api/scala/index.html#org.apache.spark.ml.feature.HashingTF) and the [IDF API docs](api/scala/index.html#org.apache.spark.ml.feature.IDF). + +In the following code segment, we start with a set of sentences. We split each sentence into words using `Tokenizer`. For each sentence (bag of words), we use `HashingTF` to hash the sentence into a feature vector. We use `IDF` to rescale the feature vectors; this generally improves performance when using text as features. Our feature vectors could then be passed to a learning algorithm.
{% highlight scala %} -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer} -val sentenceDataFrame = sqlContext.createDataFrame(Seq( +val sentenceData = sqlContext.createDataFrame(Seq( (0, "Hi I heard about Spark"), (0, "I wish Java could use case classes"), (1, "Logistic regression models are neat") )).toDF("label", "sentence") val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val wordsDataFrame = tokenizer.transform(sentenceDataFrame) -val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features").setNumFeatures(20) -val featurized = hashingTF.transform(wordsDataFrame) -featurized.select("features", "label").take(3).foreach(println) +val wordsData = tokenizer.transform(sentenceData) +val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(20) +val featurizedData = hashingTF.transform(wordsData) +val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features") +val idfModel = idf.fit(featurizedData) +val rescaledData = idfModel.transform(featurizedData) +rescaledData.select("features", "label").take(3).foreach(println) {% endhighlight %}
@@ -51,6 +59,7 @@ import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.IDF; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; @@ -70,16 +79,19 @@ StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); -DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); +DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema); Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); +DataFrame wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; HashingTF hashingTF = new HashingTF() .setInputCol("words") - .setOutputCol("features") + .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); -DataFrame featurized = hashingTF.transform(wordsDataFrame); -for (Row r : featurized.select("features", "label").take(3)) { +DataFrame featurizedData = hashingTF.transform(wordsData); +IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); +IDFModel idfModel = idf.fit(featurizedData); +DataFrame rescaledData = idfModel.transform(featurizedData); +for (Row r : rescaledData.select("features", "label").take(3)) { Vector features = r.getAs(0); Double label = r.getDouble(1); System.out.println(features); @@ -89,19 +101,22 @@ for (Row r : featurized.select("features", "label").take(3)) {
{% highlight python %} -from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.feature import HashingTF, IDF, Tokenizer -sentenceDataFrame = sqlContext.createDataFrame([ +sentenceData = sqlContext.createDataFrame([ (0, "Hi I heard about Spark"), (0, "I wish Java could use case classes"), (1, "Logistic regression models are neat") ], ["label", "sentence"]) tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsDataFrame = tokenizer.transform(sentenceDataFrame) -hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20) -featurized = hashingTF.transform(wordsDataFrame) -for features_label in featurized.select("features", "label").take(3): - print features_label +wordsData = tokenizer.transform(sentenceData) +hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=20) +featurizedData = hashingTF.transform(wordsData) +idf = IDF(inputCol="rawFeatures", outputCol="features") +idfModel = idf.fit(featurizedData) +rescaledData = idfModel.transform(featurizedData) +for features_label in rescaledData.select("features", "label").take(3): + print(features_label) {% endhighlight %}
@@ -267,11 +282,12 @@ sentenceDataFrame = sqlContext.createDataFrame([ tokenizer = Tokenizer(inputCol="sentence", outputCol="words") wordsDataFrame = tokenizer.transform(sentenceDataFrame) for words_label in wordsDataFrame.select("words", "label").take(3): - print words_label + print(words_label) {% endhighlight %} + ## Binarizer Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. @@ -352,7 +368,7 @@ binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_fe binarizedDataFrame = binarizer.transform(continuousDataFrame) binarizedFeatures = binarizedDataFrame.select("binarized_feature") for binarized_feature, in binarizedFeatures.collect(): - print binarized_feature + print(binarized_feature) {% endhighlight %} @@ -618,5 +634,161 @@ indexedData = indexerModel.transform(data) + +## Normalizer + +`Normalizer` is a `Transformer` which transforms a dataset of `Vector` rows, normalizing each `Vector` to have unit norm. It takes parameter `p`, which specifies the [p-norm](http://en.wikipedia.org/wiki/Norm_%28mathematics%29#p-norm) used for normalization. ($p = 2$ by default.) This normalization can help standardize your input data and improve the behavior of learning algorithms. + +The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^2$ norm and unit $L^\infty$ norm. + +
+
+{% highlight scala %} +import org.apache.spark.ml.feature.Normalizer +import org.apache.spark.mllib.util.MLUtils + +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +val dataFrame = sqlContext.createDataFrame(data) + +// Normalize each Vector using $L^1$ norm. +val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) +val l1NormData = normalizer.transform(dataFrame) + +// Normalize each Vector using $L^\infty$ norm. +val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) +{% endhighlight %} +
+ +
+{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.sql.DataFrame; + +JavaRDD data = + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); +DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); + +// Normalize each Vector using $L^1$ norm. +Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); +DataFrame l1NormData = normalizer.transform(dataFrame); + +// Normalize each Vector using $L^\infty$ norm. +DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); +{% endhighlight %} +
+ +
+{% highlight python %} +from pyspark.mllib.util import MLUtils +from pyspark.ml.feature import Normalizer + +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +dataFrame = sqlContext.createDataFrame(data) + +# Normalize each Vector using $L^1$ norm. +normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) +l1NormData = normalizer.transform(dataFrame) + +# Normalize each Vector using $L^\infty$ norm. +lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) +{% endhighlight %} +
+
+ + +## StandardScaler + +`StandardScaler` transforms a dataset of `Vector` rows, normalizing each feature to have unit standard deviation and/or zero mean. It takes parameters: + +* `withStd`: True by default. Scales the data to unit standard deviation. +* `withMean`: False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception. + +`StandardScaler` is a `Model` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. + +Note that if the standard deviation of a feature is zero, it will return default `0.0` value in the `Vector` for that feature. + +More details can be found in the API docs for +[StandardScaler](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) and +[StandardScalerModel](api/scala/index.html#org.apache.spark.ml.feature.StandardScalerModel). + +The following example demonstrates how to load a dataset in libsvm format and then normalize each feature to have unit standard deviation. + +
+
+{% highlight scala %} +import org.apache.spark.ml.feature.StandardScaler +import org.apache.spark.mllib.util.MLUtils + +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +val dataFrame = sqlContext.createDataFrame(data) +val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + +// Compute summary statistics by fitting the StandardScaler +val scalerModel = scaler.fit(dataFrame) + +// Normalize each feature to have unit standard deviation. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %} +
+ +
+{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.sql.DataFrame; + +JavaRDD data = + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); +DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + +// Compute summary statistics by fitting the StandardScaler +StandardScalerModel scalerModel = scaler.fit(dataFrame); + +// Normalize each feature to have unit standard deviation. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %} +
+ +
+{% highlight python %} +from pyspark.mllib.util import MLUtils +from pyspark.ml.feature import StandardScaler + +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +dataFrame = sqlContext.createDataFrame(data) +scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + +# Compute summary statistics by fitting the StandardScaler +scalerModel = scaler.fit(dataFrame) + +# Normalize each feature to have unit standard deviation. +scaledData = scalerModel.transform(dataFrame) +{% endhighlight %} +
+
+ + # Feature Selectors diff --git a/docs/ml-guide.md b/docs/ml-guide.md index cac705683c8bc..c5f50ed7990f1 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -150,11 +150,12 @@ This is useful if there are two algorithms with the `maxIter` parameter in a `Pi # Algorithm Guides -There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines. +There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. **Pipelines API Algorithm Guides** * [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Ensembles](ml-ensembles.html) # Code Examples diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 8e91d62f4a907..0210950b89906 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -20,7 +20,7 @@ the supported algorithms for each type of problem. Binary Classificationlinear SVMs, logistic regression, decision trees, random forests, gradient-boosted trees, naive Bayes - Multiclass Classificationdecision trees, random forests, naive Bayes + Multiclass Classificationlogistic regression, decision trees, random forests, naive Bayes Regressionlinear least squares, Lasso, ridge regression, decision trees, random forests, gradient-boosted trees, isotonic regression @@ -31,7 +31,7 @@ the supported algorithms for each type of problem. More details for these methods can be found here: * [Linear models](mllib-linear-methods.html) - * [binary classification (SVMs, logistic regression)](mllib-linear-methods.html#binary-classification) + * [classification (SVMs, logistic regression)](mllib-linear-methods.html#classification) * [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) * [Decision trees](mllib-decision-tree.html) * [Ensembles of decision trees](mllib-ensembles.html) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 07a4d29fe7104..10f474f237bfa 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -98,9 +98,9 @@ to your version of HDFS. Some common HDFS version tags are listed on the [Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage for common HDFS versions. -Finally, you need to import some Spark classes into your program. Add the following lines: +Finally, you need to import some Spark classes into your program. Add the following line: -{% highlight scala %} +{% highlight python %} from pyspark import SparkContext, SparkConf {% endhighlight %} @@ -478,7 +478,6 @@ the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters. - ## RDD Operations @@ -915,7 +914,8 @@ The following table lists some of the common transformations supported by Spark. RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.html#pyspark.RDD)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1028,7 +1028,9 @@ The following table lists some of the common actions supported by Spark. Refer t RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.html#pyspark.RDD)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1212,9 +1214,11 @@ storage levels is: Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors to be smaller and to share a pool of memory, making it attractive in environments with large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, - the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory + the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts - from memory. + from memory. If you plan to use Tachyon as the off heap store, Spark is compatible with Tachyon + out-of-the-box. Please refer to this page + for the suggested version pairings. @@ -1565,7 +1569,8 @@ You can see some [example Spark programs](http://spark.apache.org/examples.html) In addition, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run Java and Scala examples by passing the class name to Spark's `bin/run-example` script; for instance: ./bin/run-example SparkPi @@ -1574,6 +1579,10 @@ For Python examples, use `spark-submit` instead: ./bin/spark-submit examples/src/main/python/pi.py +For R examples, use `spark-submit` instead: + + ./bin/spark-submit examples/src/main/r/dataframe.R + For help on optimizing your programs, the [configuration](configuration.html) and [tuning](tuning.html) guides provide information on best practices. They are especially important for making sure that your data is stored in memory in an efficient format. @@ -1581,4 +1590,4 @@ For help on deploying, the [cluster mode overview](cluster-overview.html) descri in distributed operation and supported cluster managers. Finally, full API documentation is available in -[Scala](api/scala/#org.apache.spark.package), [Java](api/java/) and [Python](api/python/). +[Scala](api/scala/#org.apache.spark.package), [Java](api/java/), [Python](api/python/) and [R](api/R/). diff --git a/docs/quick-start.md b/docs/quick-start.md index 81143da865cf0..bb39e4111f244 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -184,10 +184,10 @@ scala> linesWithSpark.cache() res7: spark.RDD[String] = spark.FilteredRDD@17e51082 scala> linesWithSpark.count() -res8: Long = 15 +res8: Long = 19 scala> linesWithSpark.count() -res9: Long = 15 +res9: Long = 19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -202,10 +202,10 @@ a cluster, as described in the [programming guide](programming-guide.html#initia >>> linesWithSpark.cache() >>> linesWithSpark.count() -15 +19 >>> linesWithSpark.count() -15 +19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -423,14 +423,14 @@ dependencies to `spark-submit` through its `--py-files` argument by packaging th We can run this application using the `bin/spark-submit` script: -{% highlight python %} +{% highlight bash %} # Use spark-submit to run your application $ YOUR_SPARK_HOME/bin/spark-submit \ --master local[4] \ SimpleApp.py ... Lines with a: 46, Lines with b: 23 -{% endhighlight python %} +{% endhighlight %} @@ -444,7 +444,8 @@ Congratulations on running your first Spark application! * Finally, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run them as follows: {% highlight bash %} @@ -453,4 +454,7 @@ You can run them as follows: # For Python examples, use spark-submit directly: ./bin/spark-submit examples/src/main/python/pi.py + +# For R examples, use spark-submit directly: +./bin/spark-submit examples/src/main/r/dataframe.R {% endhighlight %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 78b8e8ad515a0..5b41c0ee6e430 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -16,9 +16,9 @@ Spark SQL is a Spark module for structured data processing. It provides a progra A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. -The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell` or the `pyspark` shell. +All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. ## Starting Point: `SQLContext` @@ -64,6 +64,17 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) {% endhighlight %} + + +
+ +The entry point into all relational functionality in Spark is the +`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
@@ -130,6 +141,19 @@ df.show() {% endhighlight %} + +
+{% highlight r %} +sqlContext <- SQLContext(sc) + +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Displays the content of the DataFrame to stdout +showDF(df) +{% endhighlight %} + +
+ @@ -296,6 +320,57 @@ df.groupBy("age").count().show() {% endhighlight %} + +
+{% highlight r %} +sqlContext <- sparkRSQL.init(sc) + +# Create the DataFrame +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Show the content of the DataFrame +showDF(df) +## age name +## null Michael +## 30 Andy +## 19 Justin + +# Print the schema in a tree format +printSchema(df) +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Select only the "name" column +showDF(select(df, "name")) +## name +## Michael +## Andy +## Justin + +# Select everybody, but increment the age by 1 +showDF(select(df, df$name, df$age + 1)) +## name (age + 1) +## Michael null +## Andy 31 +## Justin 20 + +# Select people older than 21 +showDF(where(df, df$age > 21)) +## age name +## 30 Andy + +# Count people by age +showDF(count(groupBy(df, "age"))) +## age count +## null 1 +## 19 1 +## 30 1 + +{% endhighlight %} + +
+ @@ -325,6 +400,14 @@ sqlContext = SQLContext(sc) df = sqlContext.sql("SELECT * FROM table") {% endhighlight %} + +
+{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +df <- sql(sqlContext, "SELECT * FROM table") +{% endhighlight %} +
+ @@ -719,6 +802,15 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% endhighlight %} + + +
+ +{% highlight r %} +df <- loadDF(sqlContext, "people.parquet") +saveDF(select(df, "name", "age"), "namesAndAges.parquet") +{% endhighlight %} +
@@ -760,6 +852,16 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% endhighlight %} + +
+ +{% highlight r %} + +df <- loadDF(sqlContext, "people.json", "json") +saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet") + +{% endhighlight %} +
@@ -908,6 +1010,31 @@ for teenName in teenNames.collect(): +
+ +{% highlight r %} +# sqlContext from the previous example is used in this example. + +schemaPeople # The DataFrame from the previous example. + +# DataFrames can be saved as Parquet files, maintaining the schema information. +saveAsParquetFile(schemaPeople, "people.parquet") + +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# The result of loading a parquet file is also a DataFrame. +parquetFile <- parquetFile(sqlContext, "people.parquet") + +# Parquet files can also be registered as tables and then used in SQL statements. +registerTempTable(parquetFile, "parquetFile"); +teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenNames <- map(teenagers, function(p) { paste("Name:", p$name)}) +for (teenName in collect(teenNames)) { + cat(teenName, "\n") +} +{% endhighlight %} + +
+
{% highlight sql %} @@ -1033,7 +1160,7 @@ df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) df2.save("data/test_table/key=2", "parquet") # Read the partitioned table -df3 = sqlContext.parquetFile("data/test_table") +df3 = sqlContext.load("data/test_table", "parquet") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together @@ -1047,6 +1174,33 @@ df3.printSchema()
+
+ +{% highlight r %} +# sqlContext from the previous example is used in this example. + +# Create a simple DataFrame, stored into a partition directory +saveDF(df1, "data/test_table/key=1", "parquet", "overwrite") + +# Create another DataFrame in a new partition directory, +# adding a new column and dropping an existing column +saveDF(df2, "data/test_table/key=2", "parquet", "overwrite") + +# Read the partitioned table +df3 <- loadDF(sqlContext, "data/test_table", "parquet") +printSchema(df3) + +# The final schema consists of all 3 columns in the Parquet files together +# with the partiioning column appeared in the partition directory paths. +# root +# |-- single: int (nullable = true) +# |-- double: int (nullable = true) +# |-- triple: int (nullable = true) +# |-- key : int (nullable = true) +{% endhighlight %} + +
+ ### Configuration @@ -1238,6 +1392,40 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) {% endhighlight %} +
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using one of two methods in a `SQLContext`: + +* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. + +Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +line must contain a separate, self-contained valid JSON object. As a consequence, +a regular multi-line JSON file will most often fail. + +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRSQL.init(sc) + +# A JSON dataset is pointed to by path. +# The path can be either a single text file or a directory storing text files. +path <- "examples/src/main/resources/people.json" +# Create a DataFrame from the file(s) pointed to by path +people <- jsonFile(sqlContex,t path) + +# The inferred schema can be visualized using the printSchema() method. +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql methods provided by `sqlContext`. +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +{% endhighlight %} +
+
{% highlight sql %} @@ -1314,10 +1502,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be -expressed in HiveQL. - +adds support for finding tables in the MetaStore and writing queries using HiveQL. {% highlight python %} # sc is an existing SparkContext. from pyspark.sql import HiveContext @@ -1331,6 +1516,24 @@ results = sqlContext.sql("FROM src SELECT key, value").collect() {% endhighlight %} +
+ +
+ +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +adds support for finding tables in the MetaStore and writing queries using HiveQL. +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRHive.init(sc) + +hql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results = sqlContext.sql("FROM src SELECT key, value").collect() + +{% endhighlight %} +
@@ -1430,6 +1633,16 @@ df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="sch +
+ +{% highlight r %} + +df <- loadDF(sqlContext, source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") + +{% endhighlight %} + +
+
{% highlight sql %} @@ -2354,5 +2567,151 @@ from pyspark.sql.types import *
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Data typeValue type in RAPI to access or create a data type
ByteType + integer
+ Note: Numbers will be converted to 1-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -128 to 127. +
+ "byte" +
ShortType + integer
+ Note: Numbers will be converted to 2-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -32768 to 32767. +
+ "short" +
IntegerType integer + "integer" +
LongType + integer
+ Note: Numbers will be converted to 8-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of + -9223372036854775808 to 9223372036854775807. + Otherwise, please convert data to decimal.Decimal and use DecimalType. +
+ "long" +
FloatType + numeric
+ Note: Numbers will be converted to 4-byte single-precision floating + point numbers at runtime. +
+ "float" +
DoubleType numeric + "double" +
DecimalType Not supported + Not supported +
StringType character + "string" +
BinaryType raw + "binary" +
BooleanType logical + "bool" +
TimestampType POSIXct + "timestamp" +
DateType Date + "date" +
ArrayType vector or list + list(type="array", elementType=elementType, containsNull=[containsNull])
+ Note: The default value of containsNull is True. +
MapType enviroment + list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
+ Note: The default value of valueContainsNull is True. +
StructType named list + list(type="struct", fields=fields)
+ Note: fields is a Seq of StructFields. Also, two fields with the same + name are not allowed. +
StructField The value type in R of the data type of this field + (For example, integer for a StructField with the data type IntegerType) + list(name=name, type=dataType, nullable=nullable) +
+ +
+ diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 5b82a14fba413..c5ae5d043b8ea 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -18,6 +18,7 @@ from __future__ import print_function import sys +import json from pyspark import SparkContext @@ -27,24 +28,24 @@ hbase(main):016:0> create 'test', 'f1' 0 row(s) in 1.0430 seconds -hbase(main):017:0> put 'test', 'row1', 'f1', 'value1' +hbase(main):017:0> put 'test', 'row1', 'f1:a', 'value1' 0 row(s) in 0.0130 seconds -hbase(main):018:0> put 'test', 'row2', 'f1', 'value2' +hbase(main):018:0> put 'test', 'row1', 'f1:b', 'value2' 0 row(s) in 0.0030 seconds -hbase(main):019:0> put 'test', 'row3', 'f1', 'value3' +hbase(main):019:0> put 'test', 'row2', 'f1', 'value3' 0 row(s) in 0.0050 seconds -hbase(main):020:0> put 'test', 'row4', 'f1', 'value4' +hbase(main):020:0> put 'test', 'row3', 'f1', 'value4' 0 row(s) in 0.0110 seconds hbase(main):021:0> scan 'test' ROW COLUMN+CELL - row1 column=f1:, timestamp=1401883411986, value=value1 - row2 column=f1:, timestamp=1401883415212, value=value2 - row3 column=f1:, timestamp=1401883417858, value=value3 - row4 column=f1:, timestamp=1401883420805, value=value4 + row1 column=f1:a, timestamp=1401883411986, value=value1 + row1 column=f1:b, timestamp=1401883415212, value=value2 + row2 column=f1:, timestamp=1401883417858, value=value3 + row3 column=f1:, timestamp=1401883420805, value=value4 4 row(s) in 0.0240 seconds """ if __name__ == "__main__": @@ -64,6 +65,8 @@ table = sys.argv[2] sc = SparkContext(appName="HBaseInputFormat") + # Other options for configuring scan behavior are available. More information available at + # https://github.com/apache/hbase/blob/master/hbase-server/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormat.java conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} if len(sys.argv) > 3: conf = {"hbase.zookeeper.quorum": host, "zookeeper.znode.parent": sys.argv[3], @@ -78,6 +81,8 @@ keyConverter=keyConv, valueConverter=valueConv, conf=conf) + hbase_rdd = hbase_rdd.flatMapValues(lambda v: v.split("\n")).mapValues(json.loads) + output = hbase_rdd.collect() for (k, v) in output: print((k, v)) diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 849887d23c9cf..95c96111c9b1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -59,5 +59,6 @@ object HBaseTest { hBaseRDD.count() sc.stop() + admin.close() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 273bee0a8b30f..90d48a64106c7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -18,20 +18,34 @@ package org.apache.spark.examples.pythonconverters import scala.collection.JavaConversions._ +import scala.util.parsing.json.JSONObject import org.apache.spark.api.python.Converter import org.apache.hadoop.hbase.client.{Put, Result} import org.apache.hadoop.hbase.io.ImmutableBytesWritable import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.KeyValue.Type +import org.apache.hadoop.hbase.CellUtil /** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts an - * HBase Result to a String + * Implementation of [[org.apache.spark.api.python.Converter]] that converts all + * the records in an HBase Result to a String */ class HBaseResultToStringConverter extends Converter[Any, String] { override def convert(obj: Any): String = { + import collection.JavaConverters._ val result = obj.asInstanceOf[Result] - Bytes.toStringBinary(result.value()) + val output = result.listCells.asScala.map(cell => + Map( + "row" -> Bytes.toStringBinary(CellUtil.cloneRow(cell)), + "columnFamily" -> Bytes.toStringBinary(CellUtil.cloneFamily(cell)), + "qualifier" -> Bytes.toStringBinary(CellUtil.cloneQualifier(cell)), + "timestamp" -> cell.getTimestamp.toString, + "type" -> Type.codeToType(cell.getTypeByte).toString, + "value" -> Bytes.toStringBinary(CellUtil.cloneValue(cell)) + ) + ) + output.map(JSONObject(_).toString()).mkString("\n") } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 800202e9fb86a..7dd8bfdc2a6db 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -18,6 +18,8 @@ package org.apache.spark.streaming.kinesis import java.util.UUID +import scala.util.control.NonFatal + import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} @@ -98,6 +100,9 @@ private[kinesis] class KinesisReceiver( */ private var worker: Worker = null + /** Thread running the worker */ + private var workerThread: Thread = null + /** * This is called when the KinesisReceiver starts and must be non-blocking. * The KCL creates and manages the receiving/processing thread pool through Worker.run(). @@ -126,8 +131,19 @@ private[kinesis] class KinesisReceiver( } worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) - worker.run() - + workerThread = new Thread() { + override def run(): Unit = { + try { + worker.run() + } catch { + case NonFatal(e) => + restart("Error running the KCL worker in Receiver", e) + } + } + } + workerThread.setName(s"Kinesis Receiver ${streamId}") + workerThread.setDaemon(true) + workerThread.start() logInfo(s"Started receiver with workerId $workerId") } @@ -137,10 +153,14 @@ private[kinesis] class KinesisReceiver( * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. */ override def onStop() { - if (worker != null) { - worker.shutdown() + if (workerThread != null) { + if (worker != null) { + worker.shutdown() + worker = null + } + workerThread.join() + workerThread = null logInfo(s"Stopped receiver for workerId $workerId") - worker = null } workerId = null } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index b114bcff92d0f..2531aebe7813c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -63,9 +63,12 @@ object KinesisUtils { checkpointInterval: Duration, storageLevel: StorageLevel ): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream( - new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, checkpointInterval, storageLevel, None)) + // Setting scope to override receiver stream's scope of "receiver stream" + ssc.withNamedScope("kinesis stream") { + ssc.receiverStream( + new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, checkpointInterval, storageLevel, None)) + } } /** diff --git a/make-distribution.sh b/make-distribution.sh index 8d6e91d67593f..a2b0c431fb4d0 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -231,6 +231,11 @@ cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" cp -r "$SPARK_HOME/sbin" "$DISTDIR" cp -r "$SPARK_HOME/ec2" "$DISTDIR" +# Copy SparkR if it exists +if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then + mkdir -p "$DISTDIR"/R/lib + cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib +fi # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 7f3f3262a644f..9e16e60270141 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.DataFrame * Abstract class for estimators that fit models to data. */ @AlphaComponent -abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { +abstract class Estimator[M <: Model[M]] extends PipelineStage { /** * Fits a single model to the input data with optional parameters. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index fac54188f9f4e..43bee1b770e67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -170,7 +170,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { /** * :: AlphaComponent :: - * Represents a compiled pipeline. + * Represents a fitted pipeline. */ @AlphaComponent class PipelineModel private[ml] ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index d96b54e511e9c..38bb6a5a5391e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ * Abstract class for transformers that transform one dataset into another. */ @AlphaComponent -abstract class Transformer extends PipelineStage with Params { +abstract class Transformer extends PipelineStage { /** * Transforms the dataset with optional parameters diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index c1af09c9694ba..ddbdd00ceb159 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.Evaluator +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala similarity index 93% rename from mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala rename to mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 5f2f8c94e9ff7..cabd1c97c085c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml +package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{ParamMap, Params} @@ -29,7 +29,7 @@ import org.apache.spark.sql.DataFrame abstract class Evaluator extends Params { /** - * Evaluates the output. + * Evaluates model output and returns a scalar metric (larger is better). * * @param dataset a dataset that contains labels/observations and predictions. * @param paramMap parameter map that specifies the input columns and output metrics diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala new file mode 100644 index 0000000000000..80458928c5439 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{Param, ParamValidators} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.DoubleType + +/** + * :: AlphaComponent :: + * + * Evaluator for regression, which expects two input columns: prediction and label. + */ +@AlphaComponent +final class RegressionEvaluator(override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("regEval")) + + /** + * param for metric name in evaluation + * @group param supports mse, rmse, r2, mae as valid metric names. + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) + new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "rmse") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) + } + val metrics = new RegressionMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "rmse" => + metrics.rootMeanSquaredError + case "mse" => + metrics.meanSquaredError + case "r2" => + metrics.r2 + case "mae" => + metrics.meanAbsoluteError + } + metric + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 41564410e4965..8ddf9d6a1e138 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -75,7 +75,7 @@ class PolynomialExpansion(override val uid: String) * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the * current index and increment it properly for sparse input. */ -object PolynomialExpansion { +private[feature] object PolynomialExpansion { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 1c0009476908c..181b62f46fce8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -78,8 +78,7 @@ class VectorAssembler(override val uid: String) } } -@AlphaComponent -object VectorAssembler { +private object VectorAssembler { private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 90f0be76df44f..ed032669229ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -37,6 +37,7 @@ private[feature] trait Word2VecBase extends Params /** * The dimension of the code that you want to transform from words. + * @group param */ final val vectorSize = new IntParam( this, "vectorSize", "the dimension of codes after transforming from words") @@ -47,6 +48,7 @@ private[feature] trait Word2VecBase extends Params /** * Number of partitions for sentences of words. + * @group param */ final val numPartitions = new IntParam( this, "numPartitions", "number of partitions for sentences of words") @@ -58,6 +60,7 @@ private[feature] trait Word2VecBase extends Params /** * The minimum number of times a token must appear to be included in the word2vec model's * vocabulary. + * @group param */ final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + "appear to be included in the word2vec model's vocabulary") diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 94abfcda5cf2a..12fc5b561f76e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,7 +24,7 @@ import scala.annotation.varargs import scala.collection.mutable import scala.collection.JavaConverters._ -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} import org.apache.spark.ml.util.Identifiable /** @@ -92,9 +92,11 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } /** + * :: DeveloperApi :: * Factory methods for common validation functions for [[Param.isValid]]. * The numerical methods only support Int, Long, Float, and Double. */ +@DeveloperApi object ParamValidators { /** (private[param]) Default validation always return true */ @@ -529,11 +531,13 @@ trait Params extends Identifiable with Serializable { } /** + * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. * Java developers who need to extend [[Params]] should use this class instead. * If you need to extend a abstract class which already extends [[Params]], then that abstract * class should be Java-friendly as well. */ +@DeveloperApi abstract class JavaParams extends Params /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5c6ff2dda3604..e21ff94a20f54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -22,6 +22,7 @@ import com.github.fommil.netlib.F2jBLAS import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index 56075c9a6b39f..2a1db90f2ca2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -19,18 +19,14 @@ package org.apache.spark.ml.util import scala.collection.immutable.HashMap -import org.apache.spark.annotation.Experimental import org.apache.spark.ml.attribute._ import org.apache.spark.sql.types.StructField /** - * :: Experimental :: - * * Helper utilities for tree-based algorithms */ -@Experimental -object MetadataUtils { +private[spark] object MetadataUtils { /** * Examine a schema to identify the number of classes in a label column. diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 11592b77eb356..7cd53c6d7ef79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,15 +17,13 @@ package org.apache.spark.ml.util -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.types.{DataType, StructField, StructType} + /** - * :: DeveloperApi :: * Utils for handling schemas. */ -@DeveloperApi -object SchemaUtils { +private[spark] object SchemaUtils { // TODO: Move the utility methods to SQL. diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 23463ab5fe848..da2218056307e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -63,17 +63,22 @@ public void hashingTF() { new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceDataFrame = jsql.createDataFrame(jrdd, schema); - Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); - DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + DataFrame sentenceData = jsql.createDataFrame(jrdd, schema); + Tokenizer tokenizer = new Tokenizer() + .setInputCol("sentence") + .setOutputCol("words"); + DataFrame wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; HashingTF hashingTF = new HashingTF() .setInputCol("words") - .setOutputCol("features") + .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); - DataFrame featurized = hashingTF.transform(wordsDataFrame); - for (Row r : featurized.select("features", "words", "label").take(3)) { + DataFrame featurizedData = hashingTF.transform(wordsData); + IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); + IDFModel idfModel = idf.fit(featurizedData); + DataFrame rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").take(3)) { Vector features = r.getAs(0); Assert.assertEquals(features.size(), numFeatures); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java new file mode 100644 index 0000000000000..d82f3b7e8c076 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaNormalizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaNormalizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void normalizer() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures"); + + // Normalize each Vector using $L^2$ norm. + DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); + l2NormData.count(); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.count(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java new file mode 100644 index 0000000000000..74eb2733f06ef --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaStandardScalerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStandardScalerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void standardScaler() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.count(); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala new file mode 100644 index 0000000000000..3ea7aad5274f2 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ + +class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext { + + test("Regression Evaluator: default params") { + /** + * Here is the instruction describing how to export the test data into CSV format + * so we can validate the metrics compared with R's mmetric package. + * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, + * Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1)) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)) + * .saveAsTextFile("path") + */ + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + /** + * Using the following R code to load the data, train the model and evaluate metrics. + * + * > library("glmnet") + * > library("rminer") + * > data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * > features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * > label <- as.numeric(data$V1) + * > model <- glmnet(features, label, family="gaussian", alpha = 0, lambda = 0) + * > rmse <- mmetric(label, predict(model, features), metric='RMSE') + * > mae <- mmetric(label, predict(model, features), metric='MAE') + * > r2 <- mmetric(label, predict(model, features), metric='R2') + */ + val trainer = new LinearRegression + val model = trainer.fit(dataset) + val predictions = model.transform(dataset) + + // default = rmse + val evaluator = new RegressionEvaluator() + assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) + + // r2 score + evaluator.setMetricName("r2") + assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001) + + // mae + evaluator.setMetricName("mae") + assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 03e93a2f98f9b..11b439e7875fc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -133,7 +133,10 @@ object MimaExcludes { "org.apache.spark.sql.parquet.TestGroupWriteSupport"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), + // TODO: Remove the following rule once ParquetTest has been moved to src/test. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTest") ) ++ Seq( // SPARK-7530 Added StreamingContext.getState() ProblemFilters.exclude[MissingMethodProblem]( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1b87e4e98bd83..b9515a12bc573 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -324,6 +324,7 @@ object Hive { |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ + |import org.apache.spark.sql.hive.test.TestHive.implicits._ |import org.apache.spark.sql.types._""".stripMargin, cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index da793d9db7f91..327a11b14b5aa 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -15,6 +15,6 @@ # limitations under the License. # -from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel, Evaluator +from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel -__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel", "Evaluator"] +__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"] diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index f4655c513cae7..23c37167b3711 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -15,13 +15,72 @@ # limitations under the License. # -from pyspark.ml.wrapper import JavaEvaluator +from abc import abstractmethod, ABCMeta + +from pyspark.ml.wrapper import JavaWrapper from pyspark.ml.param import Param, Params -from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol +from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['BinaryClassificationEvaluator'] +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator'] + + +@inherit_doc +class Evaluator(Params): + """ + Base class for evaluators that compute metrics from predictions. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def _evaluate(self, dataset): + """ + Evaluates the output. + + :param dataset: a dataset that contains labels/observations and + predictions + :return: metric + """ + raise NotImplementedError() + + def evaluate(self, dataset, params={}): + """ + Evaluates the output with optional parameters. + + :param dataset: a dataset that contains labels/observations and + predictions + :param params: an optional param map that overrides embedded + params + :return: metric + """ + if isinstance(params, dict): + if params: + return self.copy(params)._evaluate(dataset) + else: + return self._evaluate(dataset) + else: + raise ValueError("Params must be a param map but got %s." % type(params)) + + +@inherit_doc +class JavaEvaluator(Evaluator, JavaWrapper): + """ + Base class for :py:class:`Evaluator`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def _evaluate(self, dataset): + """ + Evaluates the output. + :param dataset: a dataset that contains labels/observations and predictions. + :return: evaluation metric + """ + self._transfer_params_to_java() + return self._java_obj.evaluate(dataset._jdf) @inherit_doc @@ -89,6 +148,70 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", return self._set(**kwargs) +@inherit_doc +class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): + """ + Evaluator for Regression, which expects two input + columns: prediction and label. + + >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5), + ... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)] + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) + ... + >>> evaluator = RegressionEvaluator(predictionCol="raw") + >>> evaluator.evaluate(dataset) + 2.842... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"}) + 0.993... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) + 2.649... + """ + # a placeholder to make it appear in the generated doc + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation (mse|rmse|r2|mae)") + + @keyword_only + def __init__(self, predictionCol="prediction", labelCol="label", + metricName="rmse"): + """ + __init__(self, predictionCol="prediction", labelCol="label", \ + metricName="rmse") + """ + super(RegressionEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) + #: param for metric name in evaluation (mse|rmse|r2|mae) + self.metricName = Param(self, "metricName", + "metric name in evaluation (mse|rmse|r2|mae)") + self._setDefault(predictionCol="prediction", labelCol="label", + metricName="rmse") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, predictionCol="prediction", labelCol="label", + metricName="rmse"): + """ + setParams(self, predictionCol="prediction", labelCol="label", + metricName="rmse") + Sets params for regression evaluator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + if __name__ == "__main__": import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 0f38e021273b0..a563024b2cdcb 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -219,40 +219,3 @@ def _transform(self, dataset): def copy(self, extra={}): stages = [stage.copy(extra) for stage in self.stages] return PipelineModel(stages) - - -class Evaluator(Params): - """ - Base class for evaluators that compute metrics from predictions. - """ - - __metaclass__ = ABCMeta - - @abstractmethod - def _evaluate(self, dataset): - """ - Evaluates the output. - - :param dataset: a dataset that contains labels/observations and - predictions - :return: metric - """ - raise NotImplementedError() - - def evaluate(self, dataset, params={}): - """ - Evaluates the output with optional parameters. - - :param dataset: a dataset that contains labels/observations and - predictions - :param params: an optional param map that overrides embedded - params - :return: metric - """ - if isinstance(params, dict): - if params: - return self.copy(params)._evaluate(dataset) - else: - return self._evaluate(dataset) - else: - raise ValueError("Params must be a param map but got %s." % type(params)) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 4419e16184da8..7b0893e2cdadc 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -20,7 +20,7 @@ from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params -from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model +from pyspark.ml.pipeline import Estimator, Transformer, Model from pyspark.mllib.common import inherit_doc, _java2py, _py2java @@ -185,22 +185,3 @@ def _call_java(self, name, *args): sc = SparkContext._active_spark_context java_args = [_py2java(sc, arg) for arg in args] return _java2py(sc, m(*java_args)) - - -@inherit_doc -class JavaEvaluator(Evaluator, JavaWrapper): - """ - Base class for :py:class:`Evaluator`s that wrap Java/Scala - implementations. - """ - - __metaclass__ = ABCMeta - - def _evaluate(self, dataset): - """ - Evaluates the output. - :param dataset: a dataset that contains labels/observations and predictions. - :return: evaluation metric - """ - self._transfer_params_to_java() - return self._java_obj.evaluate(dataset._jdf) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 66b0bff2908b7..8fee92ae3aed5 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -18,26 +18,28 @@ """ Important classes of Spark SQL and DataFrames: - - L{SQLContext} + - :class:`pyspark.sql.SQLContext` Main entry point for :class:`DataFrame` and SQL functionality. - - L{DataFrame} + - :class:`pyspark.sql.DataFrame` A distributed collection of data grouped into named columns. - - L{Column} + - :class:`pyspark.sql.Column` A column expression in a :class:`DataFrame`. - - L{Row} + - :class:`pyspark.sql.Row` A row of data in a :class:`DataFrame`. - - L{HiveContext} + - :class:`pyspark.sql.HiveContext` Main entry point for accessing data stored in Apache Hive. - - L{GroupedData} + - :class:`pyspark.sql.GroupedData` Aggregation methods, returned by :func:`DataFrame.groupBy`. - - L{DataFrameNaFunctions} + - :class:`pyspark.sql.DataFrameNaFunctions` Methods for handling missing data (null values). - - L{DataFrameStatFunctions} + - :class:`pyspark.sql.DataFrameStatFunctions` Methods for statistics functionality. - - L{functions} + - :class:`pyspark.sql.functions` List of built-in functions available for :class:`DataFrame`. - - L{types} + - :class:`pyspark.sql.types` List of data types available. + - :class:`pyspark.sql.Window` + For working with window functions. """ from __future__ import absolute_import @@ -66,8 +68,9 @@ def deco(f): from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions from pyspark.sql.group import GroupedData from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter +from pyspark.sql.window import Window, WindowSpec __all__ = [ 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', - 'DataFrameNaFunctions', 'DataFrameStatFunctions' + 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', ] diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index baf1ecbd0a2fc..8dc5039f587f0 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -116,6 +116,8 @@ class Column(object): df.colName + 1 1 / df.colName + .. note:: Experimental + .. versionadded:: 1.3 """ @@ -164,8 +166,9 @@ def __init__(self, jc): @since(1.3) def getItem(self, key): - """An expression that gets an item at position `ordinal` out of a list, - or gets an item by key out of a dict. + """ + An expression that gets an item at position ``ordinal`` out of a list, + or gets an item by key out of a dict. >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() @@ -185,7 +188,8 @@ def getItem(self, key): @since(1.3) def getField(self, name): - """An expression that gets a field by name in a StructField. + """ + An expression that gets a field by name in a StructField. >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() @@ -219,7 +223,7 @@ def __getattr__(self, item): @since(1.3) def substr(self, startPos, length): """ - Return a :class:`Column` which is a substring of the column + Return a :class:`Column` which is a substring of the column. :param startPos: start position (int or Column) :param length: length of the substring (int or Column) @@ -242,7 +246,8 @@ def substr(self, startPos, length): @ignore_unicode_prefix @since(1.3) def inSet(self, *cols): - """ A boolean expression that is evaluated to true if the value of this + """ + A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments. >>> df[df.name.inSet("Bob", "Mike")].collect() @@ -268,7 +273,8 @@ def inSet(self, *cols): @since(1.3) def alias(self, *alias): - """Returns this column aliased with a new name or names (in the case of expressions that + """ + Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode). >>> df.select(df.age.alias("age2")).collect() @@ -284,7 +290,7 @@ def alias(self, *alias): @ignore_unicode_prefix @since(1.3) def cast(self, dataType): - """ Convert the column into type `dataType` + """ Convert the column into type ``dataType``. >>> df.select(df.age.cast("string").alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] @@ -304,25 +310,24 @@ def cast(self, dataType): astype = cast - @ignore_unicode_prefix @since(1.3) def between(self, lowerBound, upperBound): - """ A boolean expression that is evaluated to true if the value of this + """ + A boolean expression that is evaluated to true if the value of this expression is between the given columns. """ return (self >= lowerBound) & (self <= upperBound) - @ignore_unicode_prefix @since(1.4) def when(self, condition, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. See :func:`pyspark.sql.functions.when` for example usage. :param condition: a boolean :class:`Column` expression. :param value: a literal value, or a :class:`Column` expression. - """ sc = SparkContext._active_spark_context if not isinstance(condition, Column): @@ -331,10 +336,10 @@ def when(self, condition, value): jc = sc._jvm.functions.when(condition._jc, v) return Column(jc) - @ignore_unicode_prefix @since(1.4) def otherwise(self, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. See :func:`pyspark.sql.functions.when` for example usage. @@ -345,6 +350,27 @@ def otherwise(self, value): jc = self._jc.otherwise(value) return Column(jc) + @since(1.4) + def over(self, window): + """ + Define a windowing column. + + :param window: a :class:`WindowSpec` + :return: a Column + + >>> from pyspark.sql import Window + >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1) + >>> from pyspark.sql.functions import rank, min + >>> # df.select(rank().over(window), min('age').over(window)) + + .. note:: Window functions is only supported with HiveContext in 1.4 + """ + from pyspark.sql.window import WindowSpec + if not isinstance(window, WindowSpec): + raise TypeError("window should be WindowSpec") + jc = self._jc.over(window._jspec) + return Column(jc) + def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 51f12c5bb4198..22f6257dfe02d 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -585,8 +585,6 @@ def read(self): Returns a :class:`DataFrameReader` that can be used to read data in as a :class:`DataFrame`. - .. note:: Experimental - >>> sqlContext.read """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 132db90e69f59..936487519a645 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -62,6 +62,8 @@ class DataFrame(object): people.filter(people.age > 30).join(department, people.deptId == department.id)) \ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + .. note:: Experimental + .. versionadded:: 1.3 """ @@ -161,7 +163,7 @@ def insertInto(self, tableName, overwrite=False): Optionally overwriting any existing data. """ - self._jdf.insertInto(tableName, overwrite) + self.write.insertInto(tableName, overwrite) @since(1.3) def saveAsTable(self, tableName, source=None, mode="error", **options): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9b0d7f3e6656e..bbf465aca8d4d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -32,16 +32,21 @@ __all__ = [ + 'array', 'approxCountDistinct', 'coalesce', 'countDistinct', + 'explode', 'monotonicallyIncreasingId', 'rand', 'randn', 'sparkPartitionId', + 'struct', 'udf', 'when'] +__all__ += ['lag', 'lead', 'ntile'] + def _create_function(name, doc=""): """ Create a function for aggregator by name""" @@ -67,6 +72,17 @@ def _(col1, col2): return _ +def _create_window_function(name, doc=''): + """ Create a window function by name """ + def _(): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)() + return Column(jc) + _.__name__ = name + _.__doc__ = 'Window function: ' + doc + return _ + + _functions = { 'lit': 'Creates a :class:`Column` of literal value.', 'col': 'Returns a :class:`Column` based on the given column name.', @@ -130,15 +146,53 @@ def _(col1, col2): 'pow': 'Returns the value of the first argument raised to the power of the second argument.' } +_window_functions = { + 'rowNumber': + """returns a sequential number starting at 1 within a window partition. + + This is equivalent to the ROW_NUMBER function in SQL.""", + 'denseRank': + """returns the rank of rows within a window partition, without any gaps. + + The difference between rank and denseRank is that denseRank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using denseRank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. + + This is equivalent to the DENSE_RANK function in SQL.""", + 'rank': + """returns the rank of rows within a window partition. + + The difference between rank and denseRank is that denseRank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using denseRank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. + + This is equivalent to the RANK function in SQL.""", + 'cumeDist': + """returns the cumulative distribution of values within a window partition, + i.e. the fraction of rows that are below the current row. + + This is equivalent to the CUME_DIST function in SQL.""", + 'percentRank': + """returns the relative rank (i.e. percentile) of rows within a window partition. + + This is equivalent to the PERCENT_RANK function in SQL.""", +} + for _name, _doc in _functions.items(): globals()[_name] = since(1.3)(_create_function(_name, _doc)) for _name, _doc in _functions_1_4.items(): globals()[_name] = since(1.4)(_create_function(_name, _doc)) for _name, _doc in _binary_mathfunctions.items(): globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) +for _name, _doc in _window_functions.items(): + globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) del _name, _doc __all__ += _functions.keys() +__all__ += _functions_1_4.keys() __all__ += _binary_mathfunctions.keys() +__all__ += _window_functions.keys() __all__.sort() @@ -176,27 +230,6 @@ def approxCountDistinct(col, rsd=None): return Column(jc) -@since(1.4) -def explode(col): - """Returns a new row for each element in the given array or map. - - >>> from pyspark.sql import Row - >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) - >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() - [Row(anInt=1), Row(anInt=2), Row(anInt=3)] - - >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() - +---+-----+ - |key|value| - +---+-----+ - | a| b| - +---+-----+ - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.explode(_to_java_column(col)) - return Column(jc) - - @since(1.4) def coalesce(*cols): """Returns the first column that is not null. @@ -249,6 +282,27 @@ def countDistinct(col, *cols): return Column(jc) +@since(1.4) +def explode(col): + """Returns a new row for each element in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode(_to_java_column(col)) + return Column(jc) + + @since(1.4) def monotonicallyIncreasingId(): """A column that generates monotonically increasing 64-bit integers. @@ -258,7 +312,7 @@ def monotonicallyIncreasingId(): within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. - As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. @@ -349,6 +403,55 @@ def when(condition, value): return Column(jc) +@since(1.4) +def lag(col, count=1, default=None): + """ + Window function: returns the value that is `offset` rows before the current row, and + `defaultValue` if there is less than `offset` rows before the current row. For example, + an `offset` of one will return the previous row at any given point in the window partition. + + This is equivalent to the LAG function in SQL. + + :param col: name of column or expression + :param count: number of row to extend + :param default: default value + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lag(_to_java_column(col), count, default)) + + +@since(1.4) +def lead(col, count=1, default=None): + """ + Window function: returns the value that is `offset` rows after the current row, and + `defaultValue` if there is less than `offset` rows after the current row. For example, + an `offset` of one will return the next row at any given point in the window partition. + + This is equivalent to the LEAD function in SQL. + + :param col: name of column or expression + :param count: number of row to extend + :param default: default value + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lead(_to_java_column(col), count, default)) + + +@since(1.4) +def ntile(n): + """ + Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in + a window partition. Fow example, if `n` is 3, the first row will get 1, the second row will + get 2, the third row will get 3, and the fourth row will get 1... + + This is equivalent to the NTILE function in SQL. + + :param n: an integer + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.ntile(int(n))) + + class UserDefinedFunction(object): """ User defined function in Python diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 4da472a577eae..5a37a673ee80c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -49,6 +49,8 @@ class GroupedData(object): A set of methods for aggregations on a :class:`DataFrame`, created by :func:`DataFrame.groupBy`. + .. note:: Experimental + .. versionadded:: 1.3 """ diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 02b3aab2b12e4..b6fd413bec7db 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -226,17 +226,25 @@ def save(self, path=None, format=None, mode="error", **options): else: jwrite.save(path) + def insertInto(self, tableName, overwrite=False): + """ + Inserts the content of the :class:`DataFrame` to the specified table. + It requires that the schema of the class:`DataFrame` is the same as the + schema of the table. + + Optionally overwriting any existing data. + """ + self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName) + @since(1.4) def saveAsTable(self, name, format=None, mode="error", **options): """ - Saves the contents of this :class:`DataFrame` to a data source as a table. - - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. + Saves the content of the :class:`DataFrame` as the specified table. - Additionally, mode is used to specify the behavior of the saveAsTable operation when - table already exists in the data source. There are four modes: + In the case the table already exists, behavior of this function depends on the + save mode, specified by the `mode` function (default to throwing an exception). + When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be + the same as that of the existing table. * `append`: Append contents of this :class:`DataFrame` to existing data. * `overwrite`: Overwrite existing data. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7e349962416c9..5c53c3a8ed4f1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -44,6 +44,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.window import Window class ExamplePointUDT(UserDefinedType): @@ -743,11 +744,9 @@ def setUpClass(cls): try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: - cls.sqlCtx = None - return + raise unittest.SkipTest("Hive is not available") except TypeError: - cls.sqlCtx = None - return + raise unittest.SkipTest("Hive is not available") os.unlink(cls.tempdir.name) _scala_HiveContext =\ cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) @@ -761,9 +760,6 @@ def tearDownClass(cls): shutil.rmtree(cls.tempdir.name, ignore_errors=True) def test_save_and_load_table(self): - if self.sqlCtx is None: - return # no hive available, skipped - df = self.df tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -805,6 +801,27 @@ def test_save_and_load_table(self): shutil.rmtree(tmpPath) + def test_window_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.partitionBy("value").orderBy("key") + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.rowNumber().over(w), + F.rank().over(w), + F.denseRank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 1, 1, 1, 1, 1), + ("2", 1, 1, 1, 3, 1, 1, 1, 1), + ("2", 1, 2, 1, 3, 2, 1, 1, 1), + ("2", 2, 2, 2, 3, 3, 3, 2, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py new file mode 100644 index 0000000000000..0a0e006bdf83a --- /dev/null +++ b/python/pyspark/sql/window.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext +from pyspark.sql import since +from pyspark.sql.column import _to_seq, _to_java_column + +__all__ = ["Window", "WindowSpec"] + + +def _to_java_cols(cols): + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + return _to_seq(sc, cols, _to_java_column) + + +class Window(object): + + """ + Utility functions for defining window in DataFrames. + + For example: + + >>> # PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + >>> window = Window.partitionBy("country").orderBy("date").rowsBetween(-sys.maxsize, 0) + + >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING + >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + + .. note:: Experimental + + .. versionadded:: 1.4 + """ + @staticmethod + @since(1.4) + def partitionBy(*cols): + """ + Creates a :class:`WindowSpec` with the partitioning defined. + """ + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + return WindowSpec(jspec) + + @staticmethod + @since(1.4) + def orderBy(*cols): + """ + Creates a :class:`WindowSpec` with the partitioning defined. + """ + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + return WindowSpec(jspec) + + +class WindowSpec(object): + """ + A window specification that defines the partitioning, ordering, + and frame boundaries. + + Use the static methods in :class:`Window` to create a :class:`WindowSpec`. + + .. note:: Experimental + + .. versionadded:: 1.4 + """ + + _JAVA_MAX_LONG = (1 << 63) - 1 + _JAVA_MIN_LONG = - (1 << 63) + + def __init__(self, jspec): + self._jspec = jspec + + @since(1.4) + def partitionBy(self, *cols): + """ + Defines the partitioning columns in a :class:`WindowSpec`. + + :param cols: names of columns or expressions + """ + return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols))) + + @since(1.4) + def orderBy(self, *cols): + """ + Defines the ordering columns in a :class:`WindowSpec`. + + :param cols: names of columns or expressions + """ + return WindowSpec(self._jspec.orderBy(_to_java_cols(cols))) + + @since(1.4) + def rowsBetween(self, start, end): + """ + Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative positions from the current row. + For example, "0" means "current row", while "-1" means the row before + the current row, and "5" means the fifth row after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = self._JAVA_MIN_LONG + if end >= sys.maxsize: + end = self._JAVA_MAX_LONG + return WindowSpec(self._jspec.rowsBetween(start, end)) + + @since(1.4) + def rangeBetween(self, start, end): + """ + Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative from the current row. For example, + "0" means "current row", while "-1" means one off before the current row, + and "5" means the five off after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = self._JAVA_MIN_LONG + if end >= sys.maxsize: + end = self._JAVA_MAX_LONG + return WindowSpec(self._jspec.rangeBetween(start, end)) + + +def _test(): + import doctest + SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 4190b7ffe1c8f..0d460b634d9b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -55,6 +55,9 @@ object Row { // TODO: Improve the performance of this if used in performance critical part. new GenericRow(rows.flatMap(_.toSeq).toArray) } + + /** Returns an empty row. */ + val empty = apply() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 06a0504359f6e..193dc6b6546b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -71,6 +71,11 @@ trait CheckAnalysis { s"invalid expression ${b.prettyString} " + s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}") + case WindowExpression(UnresolvedWindowFunction(name, _), _) => + failAnalysis( + s"Could not resolve window function '$name'. " + + "Note that, using window functions currently requires a HiveContext") + case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => // The window spec is not valid. val reason = windowSpec.validate.get diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 4c0d70203c6f5..307a9ca9b0070 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -276,6 +276,10 @@ package object dsl { def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) + def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) + + def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) + def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e1d6ac462fbcc..939cefb71b817 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -166,6 +166,19 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } } + errorTest( + "unresolved window function", + testRelation2.select( + WindowExpression( + UnresolvedWindowFunction( + "lead", + UnresolvedAttribute("c") :: Nil), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "lead" :: "window functions currently requires a HiveContext" :: Nil) + errorTest( "too many generators", listRelation.select(Explode('list).as('a), Explode('list).as('b)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index dc0aeea7c4aea..6895aa1010956 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql import scala.language.implicitConversions -import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -889,6 +889,22 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) + /** + * Define a windowing column. + * + * {{{ + * val w = Window.partitionBy("name").orderBy("id") + * df.select( + * sum("price").over(w.rangeBetween(Long.MinValue, 2)), + * avg("price").over(w.rowsBetween(0, 4)) + * ) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def over(window: expressions.WindowSpec): Column = window.withAggregate(this) + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d78b4c2f8909c..f968577bc5848 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, Unresol import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.json.JacksonGenerator import org.apache.spark.sql.sources.CreateTableUsingAsSelect @@ -411,7 +411,7 @@ class DataFrame private[sql]( joined.left, joined.right, joinType = Inner, - Some(expressions.EqualTo( + Some(catalyst.expressions.EqualTo( joined.left.resolve(usingColumn), joined.right.resolve(usingColumn)))) ) @@ -480,8 +480,9 @@ class DataFrame private[sql]( // By the time we get here, since we have already run analysis, all attributes should've been // resolved and become AttributeReference. val cond = plan.condition.map { _.transform { - case expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => - expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) + case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) }} plan.copy(condition = cond) } @@ -1393,28 +1394,6 @@ class DataFrame private[sql]( @Experimental def write: DataFrameWriter = new DataFrameWriter(this) - /** - * :: Experimental :: - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. - * @group output - * @since 1.3.0 - */ - @Experimental - def insertInto(tableName: String, overwrite: Boolean): Unit = { - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite, ifNotExists = false)).toRdd - } - - /** - * :: Experimental :: - * Adds the rows from this RDD to the specified table. - * Throws an exception if the table already exists. - * @group output - * @since 1.3.0 - */ - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. * @group rdd @@ -1550,13 +1529,7 @@ class DataFrame private[sql]( */ @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { - if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) { - // If table already exists and the save mode is Append, - // we will just call insertInto to append the contents of this DataFrame. - insertInto(tableName, overwrite = false) - } else { - write.mode(mode).saveAsTable(tableName) - } + write.mode(mode).saveAsTable(tableName) } /** @@ -1712,9 +1685,29 @@ class DataFrame private[sql]( write.format(source).mode(mode).options(options).save() } + + /** + * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. + * @group output + */ + @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") + def insertInto(tableName: String, overwrite: Boolean): Unit = { + write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) + } + + /** + * Adds the rows from this RDD to the specified table. + * Throws an exception if the table already exists. + * @group output + */ + @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") + def insertInto(tableName: String): Unit = { + write.mode(SaveMode.Append).insertInto(tableName) + } + //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// - // End of eeprecated methods + // End of deprecated methods //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 381c10f48f3c3..b44d4c86ac5d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -94,20 +94,6 @@ class DataFrameReader private[sql](sqlContext: SQLContext) { this } - /** - * Specifies the input partitioning. If specified, the underlying data source does not need to - * discover the data partitioning scheme, and thus can speed up very large inputs. - * - * This is only applicable for Parquet at the moment. - * - * @since 1.4.0 - */ - @scala.annotation.varargs - def partitionBy(colNames: String*): DataFrameReader = { - this.partitioningColumns = Option(colNames) - this - } - /** * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by * a local or distributed file system). @@ -128,7 +114,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema = userSpecifiedSchema, - partitionColumns = partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + partitionColumns = Array.empty[String], provider = source, options = extraOptions.toMap) DataFrame(sqlContext, LogicalRelation(resolved.relation)) @@ -300,6 +286,4 @@ class DataFrameReader private[sql](sqlContext: SQLContext) { private var extraOptions = new scala.collection.mutable.HashMap[String, String] - private var partitioningColumns: Option[Seq[String]] = None - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index f2e721d4db271..5548b26cb8f80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import java.util.Properties import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} @@ -148,22 +150,66 @@ final class DataFrameWriter private[sql](df: DataFrame) { df) } + /** + * Inserts the content of the [[DataFrame]] to the specified table. It requires that + * the schema of the [[DataFrame]] is the same as the schema of the table. + * + * Because it inserts data to an existing table, format or options will be ignored. + * + * @since 1.4.0 + */ + def insertInto(tableName: String): Unit = { + val partitions = + partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + val overwrite = (mode == SaveMode.Overwrite) + df.sqlContext.executePlan(InsertIntoTable( + UnresolvedRelation(Seq(tableName)), + partitions.getOrElse(Map.empty[String, Option[String]]), + df.logicalPlan, + overwrite, + ifNotExists = false)).toRdd + } + /** * Saves the content of the [[DataFrame]] as the specified table. * + * In the case the table already exists, behavior of this function depends on the + * save mode, specified by the `mode` function (default to throwing an exception). + * When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be + * the same as that of the existing table. + * When `mode` is `Append`, the schema of the [[DataFrame]] need to be + * the same as that of the existing table, and format or options will be ignored. + * * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - val cmd = - CreateTableUsingAsSelect( - tableName, - source, - temporary = false, - partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), - mode, - extraOptions.toMap, - df.logicalPlan) - df.sqlContext.executePlan(cmd).toRdd + if (df.sqlContext.catalog.tableExists(tableName :: Nil) && mode != SaveMode.Overwrite) { + mode match { + case SaveMode.Ignore => + // Do nothing + + case SaveMode.ErrorIfExists => + throw new AnalysisException(s"Table $tableName already exists.") + + case SaveMode.Append => + // If it is Append, we just ask insertInto to handle it. We will not use insertInto + // to handle saveAsTable with Overwrite because saveAsTable can change the schema of + // the table. But, insertInto with Overwrite requires the schema of data be the same + // the schema of the table. + insertInto(tableName) + } + } else { + val cmd = + CreateTableUsingAsSelect( + tableName, + source, + temporary = false, + partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + mode, + extraOptions.toMap, + df.logicalPlan) + df.sqlContext.executePlan(cmd).toRdd + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 0ded1cce68391..a59d42cdd6028 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -314,7 +314,7 @@ private[sql] case class InMemoryColumnarTableScan( columnAccessors(i).extractTo(nextRow, i) i += 1 } - nextRow + if (attributes.isEmpty) Row.empty else nextRow } override def hasNext: Boolean = columnAccessors(0).hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala new file mode 100644 index 0000000000000..d4003b2d9cbf6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions._ + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +object Window { + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + spec.partitionBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + spec.partitionBy(cols : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + spec.orderBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + spec.orderBy(cols : _*) + } + + private def spec: WindowSpec = { + new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala new file mode 100644 index 0000000000000..c3d2246297021 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{Column, catalyst} +import org.apache.spark.sql.catalyst.expressions._ + + +/** + * :: Experimental :: + * A window specification that defines the partitioning, ordering, and frame boundaries. + * + * Use the static methods in [[Window]] to create a [[WindowSpec]]. + * + * @since 1.4.0 + */ +@Experimental +class WindowSpec private[sql]( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + frame: catalyst.expressions.WindowFrame) { + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + partitionBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + new WindowSpec(cols.map(_.expr), orderSpec, frame) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + orderBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + val sortOrder: Seq[SortOrder] = cols.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + new WindowSpec(partitionSpec, sortOrder, frame) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative positions from the current row. For example, "0" means + * "current row", while "-1" means the row before the current row, and "5" means the fifth row + * after the current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 1.4.0 + */ + def rowsBetween(start: Long, end: Long): WindowSpec = { + between(RowFrame, start, end) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * while "-1" means one off before the current row, and "5" means the five off after the + * current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 1.4.0 + */ + def rangeBetween(start: Long, end: Long): WindowSpec = { + between(RangeFrame, start, end) + } + + private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { + val boundaryStart = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedPreceding + case x if x < 0 => ValuePreceding(-start.toInt) + case x if x > 0 => ValueFollowing(start.toInt) + } + + val boundaryEnd = end match { + case 0 => CurrentRow + case Long.MaxValue => UnboundedFollowing + case x if x < 0 => ValuePreceding(-end.toInt) + case x if x > 0 => ValueFollowing(end.toInt) + } + + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd)) + } + + /** + * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. + */ + private[sql] def withAggregate(aggregate: Column): Column = { + val windowExpr = aggregate.expr match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction("first_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction("last_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case wf: WindowFunction => WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => + throw new UnsupportedOperationException(s"$x is not supported in window operation.") + } + new Column(windowExpr) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6640631cf0719..9a23cfb89ca12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -37,6 +37,7 @@ import org.apache.spark.util.Utils * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions + * @groupname window_funcs Window functions * @groupname Ungrouped Support functions for DataFrames. * @since 1.3.0 */ @@ -320,6 +321,218 @@ object functions { */ def max(columnName: String): Column = max(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // Window functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `null` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(e: Column, offset: Int): Column = { + lag(e, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `null` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(columnName: String, offset: Int): Column = { + lag(columnName, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `defaultValue` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(columnName: String, offset: Int, defaultValue: Any): Column = { + lag(Column(columnName), offset, defaultValue) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `defaultValue` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(e: Column, offset: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `null` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(columnName: String, offset: Int): Column = { + lead(columnName, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `null` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(e: Column, offset: Int): Column = { + lead(e, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `defaultValue` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(columnName: String, offset: Int, defaultValue: Any): Column = { + lead(Column(columnName), offset, defaultValue) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `defaultValue` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(e: Column, offset: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + } + + /** + * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window + * partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second + * quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * + * This is equivalent to the NTILE function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def ntile(n: Int): Column = { + UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) + } + + /** + * Window function: returns a sequential number starting at 1 within a window partition. + * + * This is equivalent to the ROW_NUMBER function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def rowNumber(): Column = { + UnresolvedWindowFunction("row_number", Nil) + } + + /** + * Window function: returns the rank of rows within a window partition, without any gaps. + * + * The difference between rank and denseRank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using denseRank + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * This is equivalent to the DENSE_RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def denseRank(): Column = { + UnresolvedWindowFunction("dense_rank", Nil) + } + + /** + * Window function: returns the rank of rows within a window partition. + * + * The difference between rank and denseRank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using denseRank + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * This is equivalent to the RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def rank(): Column = { + UnresolvedWindowFunction("rank", Nil) + } + + /** + * Window function: returns the cumulative distribution of values within a window partition, + * i.e. the fraction of rows that are below the current row. + * + * {{{ + * N = total number of rows in the partition + * cumeDist(x) = number of values before (and including) x / N + * }}} + * + * + * This is equivalent to the CUME_DIST function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def cumeDist(): Column = { + UnresolvedWindowFunction("cume_dist", Nil) + } + + /** + * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. + * + * This is computed by: + * {{{ + * (rank of row in its partition - 1) / (number of rows in the partition - 1) + * }}} + * + * This is equivalent to the PERCENT_RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def percentRank(): Column = { + UnresolvedWindowFunction("percent_rank", Nil) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index c45c431438efc..70a220cc43ab9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -129,7 +129,7 @@ private[parquet] object RowReadSupport { } /** - * A `parquet.hadoop.api.WriteSupport` for Row ojects. + * A `parquet.hadoop.api.WriteSupport` for Row objects. */ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index c03649d00bbae..dacd967cff856 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -105,10 +105,9 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation), part, query, overwrite, false) if part.isEmpty => + l @ LogicalRelation(t: HadoopFsRelation), part, query, overwrite, false) => val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - execution.ExecutedCommand( - InsertIntoHadoopFsRelation(t, query, Array.empty[String], mode)) :: Nil + execution.ExecutedCommand(InsertIntoHadoopFsRelation(t, query, mode)) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 498f7538d4f55..fbd98ef0380e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -61,7 +61,6 @@ private[sql] case class InsertIntoDataSource( private[sql] case class InsertIntoHadoopFsRelation( @transient relation: HadoopFsRelation, @transient query: LogicalPlan, - partitionColumns: Array[String], mode: SaveMode) extends RunnableCommand { @@ -100,6 +99,7 @@ private[sql] case class InsertIntoHadoopFsRelation( relation.schema, needsConversion = false) + val partitionColumns = relation.partitionColumns.fieldNames if (partitionColumns.isEmpty) { insert(new DefaultWriterContainer(relation, job), df) } else { @@ -377,13 +377,22 @@ private[sql] class DefaultWriterContainer( override def outputWriterForRow(row: Row): OutputWriter = writer override def commitTask(): Unit = { - writer.close() - super.commitTask() + try { + writer.close() + super.commitTask() + } catch { + case cause: Throwable => + super.abortTask() + throw new RuntimeException("Failed to commit task", cause) + } } override def abortTask(): Unit = { - writer.close() - super.abortTask() + try { + writer.close() + } finally { + super.abortTask() + } } } @@ -422,13 +431,21 @@ private[sql] class DynamicPartitionWriterContainer( } override def commitTask(): Unit = { - outputWriters.values.foreach(_.close()) - super.commitTask() + try { + outputWriters.values.foreach(_.close()) + super.commitTask() + } catch { case cause: Throwable => + super.abortTask() + throw new RuntimeException("Failed to commit task", cause) + } } override def abortTask(): Unit = { - outputWriters.values.foreach(_.close()) - super.abortTask() + try { + outputWriters.values.foreach(_.close()) + } finally { + super.abortTask() + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 5e723122eeab1..ca30b8e74626f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -335,7 +335,6 @@ private[sql] object ResolvedDataSource { InsertIntoHadoopFsRelation( r, project, - partitionColumns.toArray, mode)).toRdd r case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index ab33125b74c17..a3fd7f13b3db7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -35,9 +35,9 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - // We are inserting into an InsertableRelation. + // We are inserting into an InsertableRelation or HadoopFsRelation. case i @ InsertIntoTable( - l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite, ifNotExists) => { + l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation), _, child, _, _) => { // First, make sure the data to be inserted have the same number of fields with the // schema of the relation. if (l.output.size != child.output.size) { @@ -101,7 +101,20 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } } - case logical.InsertIntoTable(LogicalRelation(_: HadoopFsRelation), _, _, _, _) => // OK + case logical.InsertIntoTable(LogicalRelation(r: HadoopFsRelation), part, _, _, _) => + // We need to make sure the partition columns specified by users do match partition + // columns of the relation. + val existingPartitionColumns = r.partitionColumns.fieldNames.toSet + val specifiedPartitionColumns = part.keySet + if (existingPartitionColumns != specifiedPartitionColumns) { + failAnalysis(s"Specified partition columns " + + s"(${specifiedPartitionColumns.mkString(", ")}) " + + s"do not match the partition columns of the table. Please use " + + s"(${existingPartitionColumns.mkString(", ")}) as the partition columns.") + } else { + // OK + } + case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/README.md b/sql/core/src/main/scala/org/apache/spark/sql/test/README.md new file mode 100644 index 0000000000000..d867f181b9728 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/README.md @@ -0,0 +1,7 @@ +README +====== + +Please do not add any class in this place unless it is used by `sql/console` or Python tests. +If you need to create any classes or traits that will be used by tests from both `sql/core` and +`sql/hive`, you can add them in the `src/test` of `sql/core` (tests of `sql/hive` +depend on the test jar of `sql/core`). diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c5c4f448a7224..7c47fe454b6dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -39,6 +39,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { import org.apache.spark.sql.test.TestSQLContext.implicits._ val sqlCtx = TestSQLContext + test("SPARK-6743: no columns from cache") { + Seq( + (83, 0, 38), + (26, 0, 79), + (43, 81, 24) + ).toDF("a", "b", "c").registerTempTable("cachedData") + + cacheTable("cachedData") + checkAnswer( + sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), + Row(0) :: Row(81) :: Nil) + } + test("self join with aliases") { Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") @@ -142,7 +155,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT ABS(2.5)"), Row(2.5)) } - + test("aggregation with codegen") { val originalValue = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") @@ -194,7 +207,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { "SELECT value, sum(key) FROM testData3x GROUP BY value", (1 to 100).map(i => Row(i.toString, 3 * i))) testCodeGen( - "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", Row(5050 * 3, 5050 * 3.0) :: Nil) // AVERAGE testCodeGen( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala rename to sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 47541015a3611..a286dc5825f77 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -17,21 +17,18 @@ package org.apache.spark.sql.hive.thriftserver - - import scala.util.Random +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver -import org.scalatest.{Matchers, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql.hive.HiveContext - class UISeleniumSuite extends HiveThriftJdbcTest with WebBrowser with Matchers with BeforeAndAfterAll { @@ -75,9 +72,9 @@ class UISeleniumSuite """.stripMargin.split("\\s+").toSeq } - test("thrift server ui test") { + ignore("thrift server ui test") { withJdbcStatement(statement =>{ - val baseURL = s"http://localhost:${uiPort}" + val baseURL = s"http://localhost:$uiPort" val queries = Seq( "CREATE TABLE test_map(key INT, value STRING)", @@ -86,14 +83,14 @@ class UISeleniumSuite queries.foreach(statement.execute) eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (baseURL) - find(cssSelector("""ul li a[href*="ThriftServer"]""")) should not be(None) + go to baseURL + find(cssSelector("""ul li a[href*="ThriftServer"]""")) should not be None } eventually(timeout(10 seconds), interval(50 milliseconds)) { go to (baseURL + "/ThriftServer") - find(id("sessionstat")) should not be(None) - find(id("sqlstat")) should not be(None) + find(id("sessionstat")) should not be None + find(id("sqlstat")) should not be None // check whether statements exists queries.foreach { line => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 863a5db1bf98c..b64768ababef9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.spark.sql.catalyst.ParserDialect import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} @@ -153,11 +154,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Hive 13 as this is the version of Hive that is packaged with Spark SQL. This copy of the * client is used for execution related tasks like registering temporary functions or ensuring * that the ThreadLocal SessionState is correctly populated. This copy of Hive is *not* used - * for storing peristent metadata, and only point to a dummy metastore in a temporary directory. + * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. */ @transient protected[hive] lazy val executionHive: ClientWrapper = { - logInfo(s"Initilizing execution hive, version $hiveExecutionVersion") + logInfo(s"Initializing execution hive, version $hiveExecutionVersion") new ClientWrapper( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), config = newTemporaryConfiguration()) @@ -372,6 +373,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { ResolveHiveWindowFunction :: sources.PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) } override protected[sql] def createSession(): SQLSession = { @@ -507,8 +512,19 @@ private[hive] object HiveContext { def newTemporaryConfiguration(): Map[String, String] = { val tempDir = Utils.createTempDir() val localMetastore = new File(tempDir, "metastore").getAbsolutePath - Map( - "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$localMetastore;create=true") + val propMap: HashMap[String, String] = HashMap() + // We have to mask all properties in hive-site.xml that relates to metastore data source + // as we used a local metastore here. + HiveConf.ConfVars.values().foreach { confvar => + if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { + propMap.put(confvar.varname, confvar.defaultVal) + } + } + propMap.put("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$localMetastore;create=true") + propMap.put("datanucleus.rdbms.datastoreAdapterClassName", + "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + propMap.toMap } protected val primitiveTypes = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5b6840008f1ce..425a4005aa2c3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -516,17 +516,19 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) : LogicalPlan = { val childOutputDataTypes = child.output.map(_.dataType) + val numDynamicPartitions = p.partition.values.count(_.isEmpty) val tableOutputDataTypes = - (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) + (table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions)) + .take(child.output.length).map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { - p + InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) } else if (childOutputDataTypes.size == tableOutputDataTypes.size && childOutputDataTypes.zip(tableOutputDataTypes) .forall { case (left, right) => left.sameType(right) }) { // If both types ignoring nullability of ArrayType, MapType, StructType are the same, // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite, p.ifNotExists) + InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) } else { // Only do the casting when child output data types differ from table output data types. val castedChildOutput = child.output.zip(table.output).map { @@ -561,7 +563,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * because Hive table doesn't have nullability for ARRAY, MAP, STRUCT types. */ private[hive] case class InsertIntoHiveTable( - table: LogicalPlan, + table: MetastoreRelation, partition: Map[String, Option[String]], child: LogicalPlan, overwrite: Boolean, @@ -571,7 +573,13 @@ private[hive] case class InsertIntoHiveTable( override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = child.output - override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { + val numDynamicPartitions = partition.values.count(_.isEmpty) + + // This is the expected schema of the table prepared to be inserted into, + // including dynamic partition columns. + val tableOutput = table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions) + + override lazy val resolved: Boolean = childrenResolved && child.output.zip(tableOutput).forall { case (childAttr, tableAttr) => childAttr.dataType.sameType(tableAttr.dataType) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 964828407481e..2e06cabfa80c9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -82,9 +82,11 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { lazy val warehousePath = Utils.createTempDir() + private lazy val temporaryConfig = newTemporaryConfiguration() + /** Sets up the system initially or after a RESET command */ protected override def configure(): Map[String, String] = - newTemporaryConfiguration() ++ Map("hive.metastore.warehouse.dir" -> warehousePath.toString) + temporaryConfig ++ Map("hive.metastore.warehouse.dir" -> warehousePath.toString) val testTempDir = Utils.createTempDir() diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java new file mode 100644 index 0000000000000..c4828c4717643 --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.hive; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.hive.HiveContext; +import org.apache.spark.sql.hive.test.TestHive$; + +public class JavaDataFrameSuite { + private transient JavaSparkContext sc; + private transient HiveContext hc; + + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + hc = TestHive$.MODULE$; + sc = new JavaSparkContext(hc.sparkContext()); + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); + } + df = hc.jsonRDD(sc.parallelize(jsonObjects)); + df.registerTempTable("window_table"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. + hc.sql("DROP TABLE IF EXISTS window_table"); + } + + @Test + public void saveTableAndQueryIt() { + checkAnswer( + df.select(functions.avg("key").over( + Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), + hc.sql("SELECT avg(key) " + + "OVER (PARTITION BY value " + + " ORDER BY key " + + " ROWS BETWEEN 1 preceding and 1 following) " + + "FROM window_table").collectAsList()); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java similarity index 98% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 58fe96adab17e..64d1ce92931eb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -14,7 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.hive; + +package test.org.apache.spark.sql.hive; import java.io.File; import java.io.IOException; @@ -36,6 +37,7 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; +import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala new file mode 100644 index 0000000000000..efb3f2545db84 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +class HiveDataFrameWindowSuite extends QueryTest { + + test("reuse window partitionBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.partitionBy("key").orderBy("value") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("reuse window orderBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.orderBy("value").partitionBy("key") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("lead") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lead(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lag(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lead with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), + sql( + """SELECT + | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("rank functions in unspecific window") { + val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + max("key").over(Window.partitionBy("value").orderBy("key")), + min("key").over(Window.partitionBy("value").orderBy("key")), + mean("key").over(Window.partitionBy("value").orderBy("key")), + count("key").over(Window.partitionBy("value").orderBy("key")), + sum("key").over(Window.partitionBy("value").orderBy("key")), + ntile(2).over(Window.partitionBy("value").orderBy("key")), + rowNumber().over(Window.partitionBy("value").orderBy("key")), + denseRank().over(Window.partitionBy("value").orderBy("key")), + rank().over(Window.partitionBy("value").orderBy("key")), + cumeDist().over(Window.partitionBy("value").orderBy("key")), + percentRank().over(Window.partitionBy("value").orderBy("key"))), + sql( + s"""SELECT + |key, + |max(key) over (partition by value order by key), + |min(key) over (partition by value order by key), + |avg(key) over (partition by value order by key), + |count(key) over (partition by value order by key), + |sum(key) over (partition by value order by key), + |ntile(2) over (partition by value order by key), + |row_number() over (partition by value order by key), + |dense_rank() over (partition by value order by key), + |rank() over (partition by value order by key), + |cume_dist() over (partition by value order by key), + |percent_rank() over (partition by value order by key) + |FROM window_table""".stripMargin).collect()) + } + + test("aggregation and rows between") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and range betweens") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and rows betweens with unbounded") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), + last("value").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and range betweens with unbounded") { + val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue)) + .equalTo("2") + .as("last_v"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) + .as("avg_key1"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) + .as("avg_key2"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) + .as("avg_key3") + ), + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2", + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) + | FROM window_table""".stripMargin).collect()) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index ecb990e8aac91..acf2f7da30188 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -53,7 +53,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. - testData.insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has also been updated. checkAnswer( @@ -62,7 +62,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { ) // Add more data. - testData.insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has been updated. checkAnswer( @@ -71,7 +71,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { ) // Now overwrite. - testData.insertInto("createAndInsertTest", overwrite = true) + testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") // Make sure the registered table has also been updated. checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c4c7b634964ed..9623ef06aa9b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -608,7 +608,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { StructType( StructField("a", ArrayType(IntegerType, containsNull = false), nullable = true) :: Nil) assert(df2.schema === expectedSchema2) - df2.insertInto("arrayInParquet", overwrite = false) + df2.write.mode(SaveMode.Append).insertInto("arrayInParquet") createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a").write.mode(SaveMode.Append) .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a").write @@ -642,7 +642,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { StructType( StructField("a", mapType2, nullable = true) :: Nil) assert(df2.schema === expectedSchema2) - df2.insertInto("mapInParquet", overwrite = false) + df2.write.mode(SaveMode.Append).insertInto("mapInParquet") createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a").write.mode(SaveMode.Append) .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a").write @@ -768,7 +768,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), (6 to 34).map(i => Row(i, s"str$i"))) - createDF(40, 49).insertInto("insertParquet") + createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") checkAnswer( sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), (6 to 44).map(i => Row(i, s"str$i"))) @@ -782,7 +782,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT p.c1, c2 FROM insertParquet p"), (50 to 59).map(i => Row(i, s"str$i"))) - createDF(70, 79).insertInto("insertParquet", overwrite = true) + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") checkAnswer( sql("SELECT p.c1, c2 FROM insertParquet p"), (70 to 79).map(i => Row(i, s"str$i"))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 65c6ef03bf041..4af31d482ce42 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive @@ -415,6 +416,25 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |SELECT * FROM createdtable; """.stripMargin) + test("SPARK-7270: consider dynamic partition when comparing table output") { + sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") + sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") + + val analyzedPlan = sql( + """ + |INSERT OVERWRITE table test_partition PARTITION (b=1, c) + |SELECT 'a', 'c' from ptest + """.stripMargin).queryExecution.analyzed + + assertResult(false, "Incorrect cast detected\n" + analyzedPlan) { + var hasCast = false + analyzedPlan.collect { + case p: Project => p.transformExpressionsUp { case c: Cast => hasCast = true; c } + } + hasCast + } + } + createQueryTest("transform", "SELECT TRANSFORM (key) USING 'cat' AS (tKey) FROM src") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index ba53ed99beb03..b707f5e68489b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -425,10 +425,10 @@ class SQLQuerySuite extends QueryTest { test("SPARK-4825 save join to table") { val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() sql("CREATE TABLE test1 (key INT, value STRING)") - testData.insertInto("test1") + testData.write.mode(SaveMode.Append).insertInto("test1") sql("CREATE TABLE test2 (key INT, value STRING)") - testData.insertInto("test2") - testData.insertInto("test2") + testData.write.mode(SaveMode.Append).insertInto("test2") + testData.write.mode(SaveMode.Append).insertInto("test2") sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") checkAnswer( table("test"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 223ba65f47b90..7851f38fd4056 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -316,7 +316,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation2, _, _, _)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation2, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[ParquetRelation2].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + @@ -346,7 +346,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation2, _, _, _)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation2, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[ParquetRelation2].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 2d69b89fd9a9c..de907846b9180 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLContext} /** @@ -67,7 +67,9 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends recordWriter.write(null, new Text(serialized)) } - override def close(): Unit = recordWriter.close(context) + override def close(): Unit = { + recordWriter.close(context) + } } /** @@ -120,3 +122,39 @@ class SimpleTextRelation( } } } + +/** + * A simple example [[HadoopFsRelationProvider]]. + */ +class CommitFailureTestSource extends HadoopFsRelationProvider { + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + new CommitFailureTestRelation(paths, schema, partitionColumns, parameters)(sqlContext) + } +} + +class CommitFailureTestRelation( + override val paths: Array[String], + maybeDataSchema: Option[StructType], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + @transient sqlContext: SQLContext) + extends SimpleTextRelation( + paths, maybeDataSchema, userDefinedPartitionColumns, parameters)(sqlContext) { + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, context) { + override def close(): Unit = { + sys.error("Intentional task commitment failure for testing purpose.") + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index c7c8bcd27fbde..70328e1ef810d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path +import org.scalatest.FunSuite +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive @@ -362,16 +364,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .partitionBy("p1") .saveAsTable("t") } - - // Using different order of partition columns - intercept[Throwable] { - partitionedTestDF2.write - .format(dataSourceName) - .mode(SaveMode.Append) - .option("dataSchema", dataSchema.json) - .partitionBy("p2", "p1") - .saveAsTable("t") - } } test("saveAsTable()/load() - partitioned table - ErrorIfExists") { @@ -487,6 +479,26 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } +class CommitFailureTestRelationSuite extends FunSuite with SQLTestUtils { + import TestHive.implicits._ + + override val sqlContext = TestHive + + val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName + + test("SPARK-7684: commitTask() failure should fallback to abortTask()") { + withTempPath { file => + val df = (1 to 3).map(i => i -> s"val_$i").toDF("a", "b") + intercept[SparkException] { + df.write.format(dataSourceName).save(file.getCanonicalPath) + } + + val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) + assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) + } + } +} + class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 7b77d447ce6df..5e58ed714829e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -262,7 +262,7 @@ class StreamingContext private[streaming] ( * * Note: Return statements are NOT allowed in the given body. */ - private def withNamedScope[U](name: String)(body: => U): U = { + private[streaming] def withNamedScope[U](name: String)(body: => U): U = { RDDOperationScope.withScope(sc, name, allowNesting = false, ignoreParent = false)(body) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 5cfe43a1ce726..e4ff05e12f201 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -73,27 +73,38 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val inputInfo = InputInfo(id, blockInfos.map(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) - // Are WAL record handles present with all the blocks - val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } + if (blockInfos.nonEmpty) { + // Are WAL record handles present with all the blocks + val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } - if (areWALRecordHandlesPresent) { - // If all the blocks have WAL record handle, then create a WALBackedBlockRDD - val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray - val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray - new WriteAheadLogBackedBlockRDD[T]( - ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) - } else { - // Else, create a BlockRDD. However, if there are some blocks with WAL info but not others - // then that is unexpected and log a warning accordingly. - if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { - if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { - logError("Some blocks do not have Write Ahead Log information; " + - "this is unexpected and data may not be recoverable after driver failures") - } else { - logWarning("Some blocks have Write Ahead Log information; this is unexpected") + if (areWALRecordHandlesPresent) { + // If all the blocks have WAL record handle, then create a WALBackedBlockRDD + val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray + val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) + } else { + // Else, create a BlockRDD. However, if there are some blocks with WAL info but not + // others then that is unexpected and log a warning accordingly. + if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + logError("Some blocks do not have Write Ahead Log information; " + + "this is unexpected and data may not be recoverable after driver failures") + } else { + logWarning("Some blocks have Write Ahead Log information; this is unexpected") + } } + new BlockRDD[T](ssc.sc, blockIds) + } + } else { + // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD + // according to the configuration + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, Array.empty, Array.empty, Array.empty) + } else { + new BlockRDD[T](ssc.sc, Array.empty) } - new BlockRDD[T](ssc.sc, blockIds) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 93e6b0cd7c661..0122514f9374c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -105,6 +106,36 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("socket input stream - no block in a batch") { + withTestServer(new TestServer()) { testServer => + testServer.start() + + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.addStreamingListener(ssc.progressListener) + + val batchCounter = new BatchCounter(ssc) + val networkStream = ssc.socketTextStream( + "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + outputStream.register() + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // Make sure the first batch is finished + if (!batchCounter.waitUntilBatchesCompleted(1, 30000)) { + fail("Timeout: cannot finish all batches in 30 seconds") + } + + networkStream.generatedRDDs.foreach { case (_, rdd) => + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + } + } + } + } + test("binary records stream") { val testDir: File = null try {