diff --git a/.gitignore b/.gitignore index d162fa9cca994..d54d21b802be8 100644 --- a/.gitignore +++ b/.gitignore @@ -63,6 +63,8 @@ ec2/lib/ rat-results.txt scalastyle.txt scalastyle-output.xml +R-unit-tests.log +R/unit-tests.out # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index 8c61e67a0c7d1..8aca5a7f7a967 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -67,3 +67,5 @@ logs .*scalastyle-output.xml .*dependency-reduced-pom.xml known_translations +DESCRIPTION +NAMESPACE diff --git a/R/.gitignore b/R/.gitignore new file mode 100644 index 0000000000000..9a5889ba28b2a --- /dev/null +++ b/R/.gitignore @@ -0,0 +1,6 @@ +*.o +*.so +*.Rd +lib +pkg/man +pkg/html diff --git a/R/DOCUMENTATION.md b/R/DOCUMENTATION.md new file mode 100644 index 0000000000000..931d01549b265 --- /dev/null +++ b/R/DOCUMENTATION.md @@ -0,0 +1,12 @@ +# SparkR Documentation + +SparkR documentation is generated using in-source comments annotated using using +`roxygen2`. After making changes to the documentation, to generate man pages, +you can run the following from an R console in the SparkR home directory + + library(devtools) + devtools::document(pkg="./pkg", roclets=c("rd")) + +You can verify if your changes are good by running + + R CMD check pkg/ diff --git a/R/README.md b/R/README.md new file mode 100644 index 0000000000000..a6970e39b55f3 --- /dev/null +++ b/R/README.md @@ -0,0 +1,67 @@ +# R on Spark + +SparkR is an R package that provides a light-weight frontend to use Spark from R. + +### SparkR development + +#### Build Spark + +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +``` + build/mvn -DskipTests -Psparkr package +``` + +#### Running sparkR + +You can start using SparkR by launching the SparkR shell with + + ./bin/sparkR + +The `sparkR` script automatically creates a SparkContext with Spark by default in +local mode. To specify the Spark master of a cluster for the automatically created +SparkContext, you can run + + ./bin/sparkR --master "local[2]" + +To set other options like driver memory, executor memory etc. you can pass in the [spark-submit](http://spark.apache.org/docs/latest/submitting-applications.html) arguments to `./bin/sparkR` + +#### Using SparkR from RStudio + +If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example +``` +# Set this to where Spark is installed +Sys.setenv(SPARK_HOME="/Users/shivaram/spark") +# This line loads SparkR from the installed directory +.libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) +library(SparkR) +sc <- sparkR.init(master="local") +``` + +#### Making changes to SparkR + +The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. +Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below. + +#### Generating documentation + +The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. + +### Examples, Unit tests + +SparkR comes with several sample programs in the `examples/src/main/r` directory. +To run one of them, use `./bin/sparkR `. For example: + + ./bin/sparkR examples/src/main/r/pi.R local[2] + +You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): + + R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' + ./R/run-tests.sh + +### Running on YARN +The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run +``` +export YARN_CONF_DIR=/etc/hadoop/conf +./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +``` diff --git a/R/WINDOWS.md b/R/WINDOWS.md new file mode 100644 index 0000000000000..3f889c0ca3d1e --- /dev/null +++ b/R/WINDOWS.md @@ -0,0 +1,13 @@ +## Building SparkR on Windows + +To build SparkR on Windows, the following steps are required + +1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to +include Rtools and R in `PATH`. +2. Install +[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +`JAVA_HOME` in the system environment variables. +3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin` +directory in Maven in `PATH`. +4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). +5. Open a command shell (`cmd`) in the Spark directory and run `mvn -DskipTests -Psparkr package` diff --git a/R/create-docs.sh b/R/create-docs.sh new file mode 100755 index 0000000000000..4194172a2e115 --- /dev/null +++ b/R/create-docs.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Script to create API docs for SparkR +# This requires `devtools` and `knitr` to be installed on the machine. + +# After running this script the html docs can be found in +# $SPARK_HOME/R/pkg/html + +# Figure out where the script is +export FWDIR="$(cd "`dirname "$0"`"; pwd)" +pushd $FWDIR + +# Generate Rd file +Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' + +# Install the package +./install-dev.sh + +# Now create HTML files + +# knit_rd puts html in current working directory +mkdir -p pkg/html +pushd pkg/html + +Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")' + +popd + +popd diff --git a/R/install-dev.bat b/R/install-dev.bat new file mode 100644 index 0000000000000..008a5c668bc45 --- /dev/null +++ b/R/install-dev.bat @@ -0,0 +1,27 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Install development version of SparkR +rem + +set SPARK_HOME=%~dp0.. + +MKDIR %SPARK_HOME%\R\lib + +R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ diff --git a/R/install-dev.sh b/R/install-dev.sh new file mode 100755 index 0000000000000..55ed6f4be1a4a --- /dev/null +++ b/R/install-dev.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This scripts packages the SparkR source files (R and C files) and +# creates a package that can be loaded in R. The package is by default installed to +# $FWDIR/lib and the package can be loaded by using the following command in R: +# +# library(SparkR, lib.loc="$FWDIR/lib") +# +# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory +# to load the SparkR package on the worker nodes. + + +FWDIR="$(cd `dirname $0`; pwd)" +LIB_DIR="$FWDIR/lib" + +mkdir -p $LIB_DIR + +# Install R +R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ diff --git a/R/log4j.properties b/R/log4j.properties new file mode 100644 index 0000000000000..701adb2a3da1d --- /dev/null +++ b/R/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=R-unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION new file mode 100644 index 0000000000000..1842b97d43651 --- /dev/null +++ b/R/pkg/DESCRIPTION @@ -0,0 +1,35 @@ +Package: SparkR +Type: Package +Title: R frontend for Spark +Version: 1.4.0 +Date: 2013-09-09 +Author: The Apache Software Foundation +Maintainer: Shivaram Venkataraman +Imports: + methods +Depends: + R (>= 3.0), + methods, +Suggests: + testthat +Description: R frontend for Spark +License: Apache License (== 2.0) +Collate: + 'generics.R' + 'jobj.R' + 'SQLTypes.R' + 'RDD.R' + 'pairRDD.R' + 'column.R' + 'group.R' + 'DataFrame.R' + 'SQLContext.R' + 'broadcast.R' + 'context.R' + 'deserialize.R' + 'serialize.R' + 'sparkR.R' + 'backend.R' + 'client.R' + 'utils.R' + 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE new file mode 100644 index 0000000000000..a354cdce74afa --- /dev/null +++ b/R/pkg/NAMESPACE @@ -0,0 +1,182 @@ +#exportPattern("^[[:alpha:]]+") +exportClasses("RDD") +exportClasses("Broadcast") +exportMethods( + "aggregateByKey", + "aggregateRDD", + "cache", + "checkpoint", + "coalesce", + "cogroup", + "collect", + "collectAsMap", + "collectPartition", + "combineByKey", + "count", + "countByKey", + "countByValue", + "distinct", + "Filter", + "filterRDD", + "first", + "flatMap", + "flatMapValues", + "fold", + "foldByKey", + "foreach", + "foreachPartition", + "fullOuterJoin", + "glom", + "groupByKey", + "join", + "keyBy", + "keys", + "length", + "lapply", + "lapplyPartition", + "lapplyPartitionsWithIndex", + "leftOuterJoin", + "lookup", + "map", + "mapPartitions", + "mapPartitionsWithIndex", + "mapValues", + "maximum", + "minimum", + "numPartitions", + "partitionBy", + "persist", + "pipeRDD", + "reduce", + "reduceByKey", + "reduceByKeyLocally", + "repartition", + "rightOuterJoin", + "sampleRDD", + "saveAsTextFile", + "saveAsObjectFile", + "sortBy", + "sortByKey", + "sumRDD", + "take", + "takeOrdered", + "takeSample", + "top", + "unionRDD", + "unpersist", + "value", + "values", + "zipRDD", + "zipWithIndex", + "zipWithUniqueId" + ) + +# S3 methods exported +export( + "textFile", + "objectFile", + "parallelize", + "hashCode", + "includePackage", + "broadcast", + "setBroadcastValue", + "setCheckpointDir" + ) +export("sparkR.init") +export("sparkR.stop") +export("print.jobj") +useDynLib(SparkR, stringHashCode) +importFrom(methods, setGeneric, setMethod, setOldClass) + +# SparkRSQL + +exportClasses("DataFrame") + +exportMethods("columns", + "distinct", + "dtypes", + "explain", + "filter", + "groupBy", + "head", + "insertInto", + "intersect", + "isLocal", + "limit", + "orderBy", + "names", + "printSchema", + "registerTempTable", + "repartition", + "sampleDF", + "saveAsParquetFile", + "saveAsTable", + "saveDF", + "schema", + "select", + "selectExpr", + "show", + "showDF", + "sortDF", + "subtract", + "toJSON", + "toRDD", + "unionAll", + "where", + "withColumn", + "withColumnRenamed") + +exportClasses("Column") + +exportMethods("abs", + "alias", + "approxCountDistinct", + "asc", + "avg", + "cast", + "contains", + "countDistinct", + "desc", + "endsWith", + "getField", + "getItem", + "isNotNull", + "isNull", + "last", + "like", + "lower", + "max", + "mean", + "min", + "rlike", + "sqrt", + "startsWith", + "substr", + "sum", + "sumDistinct", + "upper") + +exportClasses("GroupedData") +exportMethods("agg") + +export("sparkRSQL.init", + "sparkRHive.init") + +export("cacheTable", + "clearCache", + "createDataFrame", + "createExternalTable", + "dropTempTable", + "jsonFile", + "jsonRDD", + "loadDF", + "parquetFile", + "sql", + "table", + "tableNames", + "tables", + "toDF", + "uncacheTable") + +export("print.structType", + "print.structField") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R new file mode 100644 index 0000000000000..feafd56909a67 --- /dev/null +++ b/R/pkg/R/DataFrame.R @@ -0,0 +1,1270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DataFrame.R - DataFrame class and methods implemented in S4 OO classes + +#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame +#' @description DataFrames can be created using functions like +#' \code{jsonFile}, \code{table} etc. +#' @rdname DataFrame +#' @seealso jsonFile, table +#' +#' @param env An R environment that stores bookkeeping states of the DataFrame +#' @param sdf A Java object reference to the backing Scala DataFrame +#' @export +setClass("DataFrame", + slots = list(env = "environment", + sdf = "jobj")) + +setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { + .Object@env <- new.env() + .Object@env$isCached <- isCached + + .Object@sdf <- sdf + .Object +}) + +#' @rdname DataFrame +#' @export +dataFrame <- function(sdf, isCached = FALSE) { + new("DataFrame", sdf, isCached) +} + +############################ DataFrame Methods ############################################## + +#' Print Schema of a DataFrame +#' +#' Prints out the schema in tree format +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname printSchema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' printSchema(df) +#'} +setMethod("printSchema", + signature(x = "DataFrame"), + function(x) { + schemaString <- callJMethod(schema(x)$jobj, "treeString") + cat(schemaString) + }) + +#' Get schema object +#' +#' Returns the schema of this DataFrame as a structType object. +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname schema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dfSchema <- schema(df) +#'} +setMethod("schema", + signature(x = "DataFrame"), + function(x) { + structType(callJMethod(x@sdf, "schema")) + }) + +#' Explain +#' +#' Print the logical and physical Catalyst plans to the console for debugging. +#' +#' @param x A SparkSQL DataFrame +#' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @rdname explain +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' explain(df, TRUE) +#'} +setMethod("explain", + signature(x = "DataFrame"), + function(x, extended = FALSE) { + queryExec <- callJMethod(x@sdf, "queryExecution") + if (extended) { + cat(callJMethod(queryExec, "toString")) + } else { + execPlan <- callJMethod(queryExec, "executedPlan") + cat(callJMethod(execPlan, "toString")) + } + }) + +#' isLocal +#' +#' Returns True if the `collect` and `take` methods can be run locally +#' (without any Spark executors). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname isLocal +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' isLocal(df) +#'} +setMethod("isLocal", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "isLocal") + }) + +#' ShowDF +#' +#' Print the first numRows rows of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @param numRows The number of rows to print. Defaults to 20. +#' +#' @rdname showDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' showDF(df) +#'} +setMethod("showDF", + signature(x = "DataFrame"), + function(x, numRows = 20) { + cat(callJMethod(x@sdf, "showString", numToInt(numRows)), "\n") + }) + +#' show +#' +#' Print the DataFrame column names and types +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname show +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' show(df) +#'} +setMethod("show", "DataFrame", + function(object) { + cols <- lapply(dtypes(object), function(l) { + paste(l, collapse = ":") + }) + s <- paste(cols, collapse = ", ") + cat(paste("DataFrame[", s, "]\n", sep = "")) + }) + +#' DataTypes +#' +#' Return all column names and their data types as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname dtypes +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dtypes(df) +#'} +setMethod("dtypes", + signature(x = "DataFrame"), + function(x) { + lapply(schema(x)$fields(), function(f) { + c(f$name(), f$dataType.simpleString()) + }) + }) + +#' Column names +#' +#' Return all column names as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname columns +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' columns(df) +#'} +setMethod("columns", + signature(x = "DataFrame"), + function(x) { + sapply(schema(x)$fields(), function(f) { + f$name() + }) + }) + +#' @rdname columns +#' @export +setMethod("names", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' Register Temporary Table +#' +#' Registers a DataFrame as a Temporary Table in the SQLContext +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' +#' @rdname registerTempTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "json_df") +#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#'} +setMethod("registerTempTable", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName) { + callJMethod(x@sdf, "registerTempTable", tableName) + }) + +#' insertInto +#' +#' Insert the contents of a DataFrame into a table registered in the current SQL Context. +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' @param overwrite A logical argument indicating whether or not to overwrite +#' the existing rows in the table. +#' +#' @rdname insertInto +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' df2 <- loadDF(sqlCtx, path2, "parquet") +#' registerTempTable(df, "table1") +#' insertInto(df2, "table1", overwrite = TRUE) +#'} +setMethod("insertInto", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName, overwrite = FALSE) { + callJMethod(x@sdf, "insertInto", tableName, overwrite) + }) + +#' Cache +#' +#' Persist with the default storage level (MEMORY_ONLY). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname cache-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' cache(df) +#'} +setMethod("cache", + signature(x = "DataFrame"), + function(x) { + cached <- callJMethod(x@sdf, "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist +#' +#' Persist this DataFrame with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The DataFrame to persist +#' @rdname persist +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#'} +setMethod("persist", + signature(x = "DataFrame", newLevel = "character"), + function(x, newLevel) { + callJMethod(x@sdf, "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist +#' +#' Mark this DataFrame as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The DataFrame to unpersist +#' @param blocking Whether to block until all blocks are deleted +#' @rdname unpersist-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#' unpersist(df) +#'} +setMethod("unpersist", + signature(x = "DataFrame"), + function(x, blocking = TRUE) { + callJMethod(x@sdf, "unpersist", blocking) + x@env$isCached <- FALSE + x + }) + +#' Repartition +#' +#' Return a new DataFrame that has exactly numPartitions partitions. +#' +#' @param x A SparkSQL DataFrame +#' @param numPartitions The number of partitions to use. +#' @rdname repartition +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- repartition(df, 2L) +#'} +setMethod("repartition", + signature(x = "DataFrame", numPartitions = "numeric"), + function(x, numPartitions) { + sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) + dataFrame(sdf) + }) + +#' toJSON +#' +#' Convert the rows of a DataFrame into JSON objects and return an RDD where +#' each element contains a JSON string. +#' +#' @param x A SparkSQL DataFrame +#' @return A StringRRDD of JSON objects +#' @rdname tojson +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newRDD <- toJSON(df) +#'} +setMethod("toJSON", + signature(x = "DataFrame"), + function(x) { + rdd <- callJMethod(x@sdf, "toJSON") + jrdd <- callJMethod(rdd, "toJavaRDD") + RDD(jrdd, serializedMode = "string") + }) + +#' saveAsParquetFile +#' +#' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out +#' with this method can be read back in as a DataFrame using parquetFile(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' @rdname saveAsParquetFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#'} +setMethod("saveAsParquetFile", + signature(x = "DataFrame", path = "character"), + function(x, path) { + invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + }) + +#' Distinct +#' +#' Return a new DataFrame containing the distinct rows in this DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @rdname distinct +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' distinctDF <- distinct(df) +#'} +setMethod("distinct", + signature(x = "DataFrame"), + function(x) { + sdf <- callJMethod(x@sdf, "distinct") + dataFrame(sdf) + }) + +#' SampleDF +#' +#' Return a sampled subset of this DataFrame using a random seed. +#' +#' @param x A SparkSQL DataFrame +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @rdname sampleDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collect(sampleDF(df, FALSE, 0.5)) +#' collect(sampleDF(df, TRUE, 0.5)) +#'} +setMethod("sampleDF", + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + signature(x = "DataFrame", withReplacement = "logical", + fraction = "numeric"), + function(x, withReplacement, fraction) { + if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + dataFrame(sdf) + }) + +#' Count +#' +#' Returns the number of rows in a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname count +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' count(df) +#' } +setMethod("count", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "count") + }) + +#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame. +#' +#' @param x A SparkSQL DataFrame +#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns +#' should be converted to factors. FALSE by default. + +#' @rdname collect-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collected <- collect(df) +#' firstName <- collected[[1]]$name +#' } +setMethod("collect", + signature(x = "DataFrame"), + function(x, stringsAsFactors = FALSE) { + # listCols is a list of raw vectors, one per column + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + cols <- lapply(listCols, function(col) { + objRaw <- rawConnection(col) + numRows <- readInt(objRaw) + col <- readCol(objRaw, numRows) + close(objRaw) + col + }) + names(cols) <- columns(x) + do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) + }) + +#' Limit +#' +#' Limit the resulting DataFrame to the number of rows specified. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return +#' @return A new DataFrame containing the number of rows specified. +#' +#' @rdname limit +#' @export +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' limitedDF <- limit(df, 10) +#' } +setMethod("limit", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + res <- callJMethod(x@sdf, "limit", as.integer(num)) + dataFrame(res) + }) + +# Take the first NUM rows of a DataFrame and return a the results as a data.frame + +#' @rdname take +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' take(df, 2) +#' } +setMethod("take", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + limited <- limit(x, num) + collect(limited) + }) + +#' Head +#' +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame +#' convention in R. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return. Default is 6. +#' @return A data.frame +#' +#' @rdname head +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' head(df) +#' } +setMethod("head", + signature(x = "DataFrame"), + function(x, num = 6L) { + # Default num is 6L in keeping with R's data.frame convention + take(x, num) + }) + +#' Return the first row of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' first(df) +#' } +setMethod("first", + signature(x = "DataFrame"), + function(x) { + take(x, 1) + }) + +#' toRDD() +#' +#' Converts a Spark DataFrame to an RDD while preserving column names. +#' +#' @param x A Spark DataFrame +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' rdd <- toRDD(df) +#' } +setMethod("toRDD", + signature(x = "DataFrame"), + function(x) { + jrdd <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToRowRDD", x@sdf) + colNames <- callJMethod(x@sdf, "columns") + rdd <- RDD(jrdd, serializedMode = "row") + lapply(rdd, function(row) { + names(row) <- colNames + row + }) + }) + +#' GroupBy +#' +#' Groups the DataFrame using the specified columns, so we can run aggregation on them. +#' +#' @param x a DataFrame +#' @return a GroupedData +#' @seealso GroupedData +#' @rdname DataFrame +#' @export +#' @examples +#' \dontrun{ +#' # Compute the average for all numeric columns grouped by department. +#' avg(groupBy(df, "department")) +#' +#' # Compute the max age and average salary, grouped by department and gender. +#' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") +#' } +setMethod("groupBy", + signature(x = "DataFrame"), + function(x, ...) { + cols <- list(...) + if (length(cols) >= 1 && class(cols[[1]]) == "character") { + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + } else { + jcol <- lapply(cols, function(c) { c@jc }) + sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + } + groupedData(sgd) + }) + +#' Agg +#' +#' Compute aggregates by specifying a list of columns +#' +#' @rdname DataFrame +#' @export +setMethod("agg", + signature(x = "DataFrame"), + function(x, ...) { + agg(groupBy(x), ...) + }) + + +############################## RDD Map Functions ################################## +# All of the following functions mirror the existing RDD map functions, # +# but allow for use with DataFrames by first converting to an RRDD before calling # +# the requested map function. # +################################################################################### + +#' @rdname lapply +setMethod("lapply", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapply(rdd, FUN) + }) + +#' @rdname lapply +setMethod("map", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' @rdname flatMap +setMethod("flatMap", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + flatMap(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("lapplyPartition", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapplyPartition(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("mapPartitions", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' @rdname foreach +setMethod("foreach", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreach(rdd, func) + }) + +#' @rdname foreach +setMethod("foreachPartition", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreachPartition(rdd, func) + }) + + +############################## SELECT ################################## + +getColumn <- function(x, c) { + column(callJMethod(x@sdf, "col", c)) +} + +#' @rdname select +setMethod("$", signature(x = "DataFrame"), + function(x, name) { + getColumn(x, name) + }) + +setMethod("$<-", signature(x = "DataFrame"), + function(x, name, value) { + stopifnot(class(value) == "Column") + cols <- columns(x) + if (name %in% cols) { + cols <- lapply(cols, function(c) { + if (c == name) { + alias(value, name) + } else { + col(c) + } + }) + nx <- select(x, cols) + } else { + nx <- withColumn(x, name, value) + } + x@sdf <- nx@sdf + x + }) + +#' @rdname select +setMethod("[[", signature(x = "DataFrame"), + function(x, i) { + if (is.numeric(i)) { + cols <- columns(x) + i <- cols[[i]] + } + getColumn(x, i) + }) + +#' @rdname select +setMethod("[", signature(x = "DataFrame", i = "missing"), + function(x, i, j, ...) { + if (is.numeric(j)) { + cols <- columns(x) + j <- cols[j] + } + if (length(j) > 1) { + j <- as.list(j) + } + select(x, j) + }) + +#' Select +#' +#' Selects a set of columns with names or Column expressions. +#' @param x A DataFrame +#' @param col A list of columns or single Column or name +#' @return A new DataFrame with selected columns +#' @export +#' @rdname select +#' @examples +#' \dontrun{ +#' select(df, "*") +#' select(df, "col1", "col2") +#' select(df, df$name, df$age + 1) +#' select(df, c("col1", "col2")) +#' select(df, list(df$name, df$age + 1)) +#' # Columns can also be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' # Similar to R data frames columns can also be selected using `$` +#' df$age +#' } +setMethod("select", signature(x = "DataFrame", col = "character"), + function(x, col, ...) { + sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", signature(x = "DataFrame", col = "Column"), + function(x, col, ...) { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", + signature(x = "DataFrame", col = "list"), + function(x, col) { + cols <- lapply(col, function(c) { + if (class(c)== "Column") { + c@jc + } else { + col(c)@jc + } + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + dataFrame(sdf) + }) + +#' SelectExpr +#' +#' Select from a DataFrame using a set of SQL expressions. +#' +#' @param x A DataFrame to be selected from. +#' @param expr A string containing a SQL expression +#' @param ... Additional expressions +#' @return A DataFrame +#' @rdname selectExpr +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' selectExpr(df, "col1", "(col2 * 5) as newCol") +#' } +setMethod("selectExpr", + signature(x = "DataFrame", expr = "character"), + function(x, expr, ...) { + exprList <- list(expr, ...) + sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + dataFrame(sdf) + }) + +#' WithColumn +#' +#' Return a new DataFrame with the specified column added. +#' +#' @param x A DataFrame +#' @param colName A string containing the name of the new column. +#' @param col A Column expression. +#' @return A DataFrame with the new column added. +#' @rdname withColumn +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumn(df, "newCol", df$col1 * 5) +#' } +setMethod("withColumn", + signature(x = "DataFrame", colName = "character", col = "Column"), + function(x, colName, col) { + select(x, x$"*", alias(col, colName)) + }) + +#' WithColumnRenamed +#' +#' Rename an existing column in a DataFrame. +#' +#' @param x A DataFrame +#' @param existingCol The name of the column you want to change. +#' @param newCol The new column name. +#' @return A DataFrame with the column name changed. +#' @rdname withColumnRenamed +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumnRenamed(df, "col1", "newCol1") +#' } +setMethod("withColumnRenamed", + signature(x = "DataFrame", existingCol = "character", newCol = "character"), + function(x, existingCol, newCol) { + cols <- lapply(columns(x), function(c) { + if (c == existingCol) { + alias(col(c), newCol) + } else { + col(c) + } + }) + select(x, cols) + }) + +setClassUnion("characterOrColumn", c("character", "Column")) + +#' SortDF +#' +#' Sort a DataFrame by the specified column(s). +#' +#' @param x A DataFrame to be sorted. +#' @param col Either a Column object or character vector indicating the field to sort on +#' @param ... Additional sorting fields +#' @return A DataFrame where all elements are sorted. +#' @rdname sortDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' sortDF(df, df$col1) +#' sortDF(df, "col1") +#' sortDF(df, asc(df$col1), desc(abs(df$col2))) +#' } +setMethod("sortDF", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col, ...) { + if (class(col) == "character") { + sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + } else if (class(col) == "Column") { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + } + dataFrame(sdf) + }) + +#' @rdname sortDF +#' @export +setMethod("orderBy", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col) { + sortDF(x, col) + }) + +#' Filter +#' +#' Filter the rows of a DataFrame according to a given condition. +#' +#' @param x A DataFrame to be sorted. +#' @param condition The condition to sort on. This may either be a Column expression +#' or a string containing a SQL statement +#' @return A DataFrame containing only the rows that meet the condition. +#' @rdname filter +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' filter(df, "col1 > 0") +#' filter(df, df$col2 != "abcdefg") +#' } +setMethod("filter", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + if (class(condition) == "Column") { + condition <- condition@jc + } + sdf <- callJMethod(x@sdf, "filter", condition) + dataFrame(sdf) + }) + +#' @rdname filter +#' @export +setMethod("where", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + filter(x, condition) + }) + +#' Join +#' +#' Join two DataFrames based on the given join expression. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' @param joinType The type of join to perform. The following join types are available: +#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' @return A DataFrame containing the result of the join operation. +#' @rdname join +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' join(df1, df2) # Performs a Cartesian +#' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression +#' join(df1, df2, df1$col1 == df2$col2, "right_outer") +#' } +setMethod("join", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL) { + if (is.null(joinExpr)) { + sdf <- callJMethod(x@sdf, "join", y@sdf) + } else { + if (class(joinExpr) != "Column") stop("joinExpr must be a Column") + if (is.null(joinType)) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) + } else { + if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) + } else { + stop("joinType must be one of the following types: ", + "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + } + } + } + dataFrame(sdf) + }) + +#' UnionAll +#' +#' Return a new DataFrame containing the union of rows in this DataFrame +#' and another DataFrame. This is equivalent to `UNION ALL` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the union. +#' @rdname unionAll +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' unioned <- unionAll(df, df2) +#' } +setMethod("unionAll", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + unioned <- callJMethod(x@sdf, "unionAll", y@sdf) + dataFrame(unioned) + }) + +#' Intersect +#' +#' Return a new DataFrame containing rows only in both this DataFrame +#' and another DataFrame. This is equivalent to `INTERSECT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the intersect. +#' @rdname intersect +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' intersectDF <- intersect(df, df2) +#' } +setMethod("intersect", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersect", y@sdf) + dataFrame(intersected) + }) + +#' Subtract +#' +#' Return a new DataFrame containing rows in this DataFrame +#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the subtract operation. +#' @rdname subtract +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' subtractDF <- subtract(df, df2) +#' } +setMethod("subtract", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + subtracted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(subtracted) + }) + +#' Save the contents of the DataFrame to a data source +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param path A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveDF", + signature(df = "DataFrame", path = 'character', source = 'character', + mode = 'character'), + function(df, path = NULL, source = NULL, mode = "append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] = path + } + callJMethod(df@sdf, "save", source, jmode, options) + }) + + +#' saveAsTable +#' +#' Save the contents of the DataFrame to a data source as a table +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param tableName A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveAsTable", + signature(df = "DataFrame", tableName = 'character', source = 'character', + mode = 'character'), + function(df, tableName, source = NULL, mode="append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + }) + diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R new file mode 100644 index 0000000000000..604ad03c407b9 --- /dev/null +++ b/R/pkg/R/RDD.R @@ -0,0 +1,1539 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# RDD in R implemented in S4 OO system. + +setOldClass("jobj") + +#' @title S4 class that represents an RDD +#' @description RDD can be created using functions like +#' \code{parallelize}, \code{textFile} etc. +#' @rdname RDD +#' @seealso parallelize, textFile +#' +#' @slot env An R environment that stores bookkeeping states of the RDD +#' @slot jrdd Java object reference to the backing JavaRDD +#' to an RDD +#' @export +setClass("RDD", + slots = list(env = "environment", + jrdd = "jobj")) + +setClass("PipelinedRDD", + slots = list(prev = "RDD", + func = "function", + prev_jrdd = "jobj"), + contains = "RDD") + +setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, + isCached, isCheckpointed) { + # Check that RDD constructor is using the correct version of serializedMode + stopifnot(class(serializedMode) == "character") + stopifnot(serializedMode %in% c("byte", "string", "row")) + # RDD has three serialization types: + # byte: The RDD stores data serialized in R. + # string: The RDD stores data as strings. + # row: The RDD stores the serialized rows of a DataFrame. + + # We use an environment to store mutable states inside an RDD object. + # Note that R's call-by-value semantics makes modifying slots inside an + # object (passed as an argument into a function, such as cache()) difficult: + # i.e. one needs to make a copy of the RDD object and sets the new slot value + # there. + + # The slots are inheritable from superclass. Here, both `env' and `jrdd' are + # inherited from RDD, but only the former is used. + .Object@env <- new.env() + .Object@env$isCached <- isCached + .Object@env$isCheckpointed <- isCheckpointed + .Object@env$serializedMode <- serializedMode + + .Object@jrdd <- jrdd + .Object +}) + +setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { + .Object@env <- new.env() + .Object@env$isCached <- FALSE + .Object@env$isCheckpointed <- FALSE + .Object@env$jrdd_val <- jrdd_val + if (!is.null(jrdd_val)) { + # This tracks the serialization mode for jrdd_val + .Object@env$serializedMode <- prev@env$serializedMode + } + + .Object@prev <- prev + + isPipelinable <- function(rdd) { + e <- rdd@env + !(e$isCached || e$isCheckpointed) + } + + if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { + # This transformation is the first in its stage: + .Object@func <- func + .Object@prev_jrdd <- getJRDD(prev) + .Object@env$prev_serializedMode <- prev@env$serializedMode + # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD + # prev_serializedMode is used during the delayed computation of JRDD in getJRDD + } else { + pipelinedFunc <- function(split, iterator) { + func(split, prev@func(split, iterator)) + } + .Object@func <- pipelinedFunc + .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline + # Get the serialization mode of the parent RDD + .Object@env$prev_serializedMode <- prev@env$prev_serializedMode + } + + .Object +}) + +#' @rdname RDD +#' @export +#' +#' @param jrdd Java object reference to the backing JavaRDD +#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +#' stores strings, and "row" if the RDD stores the rows of a DataFrame +#' @param isCached TRUE if the RDD is cached +#' @param isCheckpointed TRUE if the RDD has been checkpointed +RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, + isCheckpointed = FALSE) { + new("RDD", jrdd, serializedMode, isCached, isCheckpointed) +} + +PipelinedRDD <- function(prev, func) { + new("PipelinedRDD", prev, func, NULL) +} + +# Return the serialization mode for an RDD. +setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) +# For normal RDDs we can directly read the serializedMode +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +# For pipelined RDDs if jrdd_val is set then serializedMode should exist +# if not we return the defaultSerialization mode of "byte" as we don't know the serialization +# mode at this point in time. +setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), + function(rdd) { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$serializedMode) + } else { + return("byte") + } + }) + +# The jrdd accessor function. +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "PipelinedRDD"), + function(rdd, serializedMode = "byte") { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$jrdd_val) + } + + computeFunc <- function(split, part) { + rdd@func(split, part) + } + + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + + serializedFuncArr <- serialize(computeFunc, connection = NULL) + + prev_jrdd <- rdd@prev_jrdd + + if (serializedMode == "string") { + rddRef <- newJObject("org.apache.spark.api.r.StringRRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } else { + rddRef <- newJObject("org.apache.spark.api.r.RRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } + # Save the serialization flag after we create a RRDD + rdd@env$serializedMode <- serializedMode + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val + }) + +setValidity("RDD", + function(object) { + jrdd <- getJRDD(object) + cls <- callJMethod(jrdd, "getClass") + className <- callJMethod(cls, "getName") + if (grep("spark.api.java.*RDD*", className) == 1) { + TRUE + } else { + paste("Invalid RDD class ", className) + } + }) + + +############ Actions and Transformations ############ + +#' Persist an RDD +#' +#' Persist this RDD with the default storage level (MEMORY_ONLY). +#' +#' @param x The RDD to cache +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) +#'} +#' @rdname cache-methods +#' @aliases cache,RDD-method +setMethod("cache", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist an RDD +#' +#' Persist this RDD with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The RDD to persist +#' @param newLevel The new storage level to be assigned +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' persist(rdd, "MEMORY_AND_DISK") +#'} +#' @rdname persist +#' @aliases persist,RDD-method +setMethod("persist", + signature(x = "RDD", newLevel = "character"), + function(x, newLevel) { + callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist an RDD +#' +#' Mark the RDD as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The RDD to unpersist +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) # rdd@@env$isCached == TRUE +#' unpersist(rdd) # rdd@@env$isCached == FALSE +#'} +#' @rdname unpersist-methods +#' @aliases unpersist,RDD-method +setMethod("unpersist", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "unpersist") + x@env$isCached <- FALSE + x + }) + +#' Checkpoint an RDD +#' +#' Mark this RDD for checkpointing. It will be saved to a file inside the +#' checkpoint directory set with setCheckpointDir() and all references to its +#' parent RDDs will be removed. This function must be called before any job has +#' been executed on this RDD. It is strongly recommended that this RDD is +#' persisted in memory, otherwise saving it on a file will require recomputation. +#' +#' @param x The RDD to checkpoint +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "checkpoints") +#' rdd <- parallelize(sc, 1:10, 2L) +#' checkpoint(rdd) +#'} +#' @rdname checkpoint-methods +#' @aliases checkpoint,RDD-method +setMethod("checkpoint", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + callJMethod(jrdd, "checkpoint") + x@env$isCheckpointed <- TRUE + x + }) + +#' Gets the number of partitions of an RDD +#' +#' @param x A RDD. +#' @return the number of partitions of rdd as an integer. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' numPartitions(rdd) # 2L +#'} +#' @rdname numPartitions +#' @aliases numPartitions,RDD-method +setMethod("numPartitions", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + partitions <- callJMethod(jrdd, "splits") + callJMethod(partitions, "size") + }) + +#' Collect elements of an RDD +#' +#' @description +#' \code{collect} returns a list that contains all of the elements in this RDD. +#' +#' @param x The RDD to collect +#' @param ... Other optional arguments to collect +#' @param flatten FALSE if the list should not flattened +#' @return a list containing elements in the RDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' collect(rdd) # list from 1 to 10 +#' collectPartition(rdd, 0L) # list from 1 to 5 +#'} +#' @rdname collect-methods +#' @aliases collect,RDD-method +setMethod("collect", + signature(x = "RDD"), + function(x, flatten = TRUE) { + # Assumes a pairwise RDD is backed by a JavaPairRDD. + collected <- callJMethod(getJRDD(x), "collect") + convertJListToRList(collected, flatten, + serializedMode = getSerializedMode(x)) + }) + + +#' @description +#' \code{collectPartition} returns a list that contains all of the elements +#' in the specified partition of the RDD. +#' @param partitionId the partition to collect (starts from 0) +#' @rdname collect-methods +#' @aliases collectPartition,integer,RDD-method +setMethod("collectPartition", + signature(x = "RDD", partitionId = "integer"), + function(x, partitionId) { + jPartitionsList <- callJMethod(getJRDD(x), + "collectPartitions", + as.list(as.integer(partitionId))) + + jList <- jPartitionsList[[1]] + convertJListToRList(jList, flatten = TRUE, + serializedMode = getSerializedMode(x)) + }) + +#' @description +#' \code{collectAsMap} returns a named list as a map that contains all of the elements +#' in a key-value pair RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +#' collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#'} +#' @rdname collect-methods +#' @aliases collectAsMap,RDD-method +setMethod("collectAsMap", + signature(x = "RDD"), + function(x) { + pairList <- collect(x) + map <- new.env() + lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) }) + as.list(map) + }) + +#' Return the number of elements in the RDD. +#' +#' @param x The RDD to count +#' @return number of elements in the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' count(rdd) # 10 +#' length(rdd) # Same as count +#'} +#' @rdname count +#' @aliases count,RDD-method +setMethod("count", + signature(x = "RDD"), + function(x) { + countPartition <- function(part) { + as.integer(length(part)) + } + valsRDD <- lapplyPartition(x, countPartition) + vals <- collect(valsRDD) + sum(as.integer(vals)) + }) + +#' Return the number of elements in the RDD +#' @export +#' @rdname count +setMethod("length", + signature(x = "RDD"), + function(x) { + count(x) + }) + +#' Return the count of each unique value in this RDD as a list of +#' (value, count) pairs. +#' +#' Same as countByValue in Spark. +#' +#' @param x The RDD to count +#' @return list of (value, count) pairs, where count is number of each unique +#' value in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,3,2,1)) +#' countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#'} +#' @rdname countByValue +#' @aliases countByValue,RDD-method +setMethod("countByValue", + signature(x = "RDD"), + function(x) { + ones <- lapply(x, function(item) { list(item, 1L) }) + collect(reduceByKey(ones, `+`, numPartitions(x))) + }) + +#' Apply a function to all elements +#' +#' This function creates a new RDD by applying the given transformation to all +#' elements of the given RDD +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @rdname lapply +#' @aliases lapply +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +#' collect(multiplyByTwo) # 2,4,6... +#'} +setMethod("lapply", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(split, iterator) { + lapply(iterator, FUN) + } + lapplyPartitionsWithIndex(X, func) + }) + +#' @rdname lapply +#' @aliases map,RDD,function-method +setMethod("map", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' Flatten results after apply a function to all elements +#' +#' This function return a new RDD by first applying a function to all +#' elements of this RDD, and then flattening the results. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#'} +#' @rdname flatMap +#' @aliases flatMap,RDD,function-method +setMethod("flatMap", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + partitionFunc <- function(part) { + unlist( + lapply(part, FUN), + recursive = F + ) + } + lapplyPartition(X, partitionFunc) + }) + +#' Apply a function to each partition of an RDD +#' +#' Return a new RDD by applying a function to each partition of this RDD. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +#' collect(partitionSum) # 15, 40 +#'} +#' @rdname lapplyPartition +#' @aliases lapplyPartition,RDD,function-method +setMethod("lapplyPartition", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) + }) + +#' mapPartitions is the same as lapplyPartition. +#' +#' @rdname lapplyPartition +#' @aliases mapPartitions,RDD,function-method +setMethod("mapPartitions", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' Return a new RDD by applying a function to each partition of this RDD, while +#' tracking the index of the original partition. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition; takes the partition +#' index and a list of elements in the particular partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 5L) +#' prod <- lapplyPartitionsWithIndex(rdd, function(split, part) { +#' split * Reduce("+", part) }) +#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#'} +#' @rdname lapplyPartitionsWithIndex +#' @aliases lapplyPartitionsWithIndex,RDD,function-method +setMethod("lapplyPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + FUN <- cleanClosure(FUN) + closureCapturingFunc <- function(split, part) { + FUN(split, part) + } + PipelinedRDD(X, closureCapturingFunc) + }) + +#' @rdname lapplyPartitionsWithIndex +#' @aliases mapPartitionsWithIndex,RDD,function-method +setMethod("mapPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, FUN) + }) + +#' This function returns a new RDD containing only the elements that satisfy +#' a predicate (i.e. returning TRUE in a given logical function). +#' The same as `filter()' in Spark. +#' +#' @param x The RDD to be filtered. +#' @param f A unary predicate function. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#'} +#' @rdname filterRDD +#' @aliases filterRDD,RDD,function-method +setMethod("filterRDD", + signature(x = "RDD", f = "function"), + function(x, f) { + filter.func <- function(part) { + Filter(f, part) + } + lapplyPartition(x, filter.func) + }) + +#' @rdname filterRDD +#' @aliases Filter +setMethod("Filter", + signature(f = "function", x = "RDD"), + function(f, x) { + filterRDD(x, f) + }) + +#' Reduce across elements of an RDD. +#' +#' This function reduces the elements of this RDD using the +#' specified commutative and associative binary operator. +#' +#' @param x The RDD to reduce +#' @param func Commutative and associative function to apply on elements +#' of the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' reduce(rdd, "+") # 55 +#'} +#' @rdname reduce +#' @aliases reduce,RDD,ANY-method +setMethod("reduce", + signature(x = "RDD", func = "ANY"), + function(x, func) { + + reducePartition <- function(part) { + Reduce(func, part) + } + + partitionList <- collect(lapplyPartition(x, reducePartition), + flatten = FALSE) + Reduce(func, partitionList) + }) + +#' Get the maximum element of an RDD. +#' +#' @param x The RDD to get the maximum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' maximum(rdd) # 10 +#'} +#' @rdname maximum +#' @aliases maximum,RDD +setMethod("maximum", + signature(x = "RDD"), + function(x) { + reduce(x, max) + }) + +#' Get the minimum element of an RDD. +#' +#' @param x The RDD to get the minimum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' minimum(rdd) # 1 +#'} +#' @rdname minimum +#' @aliases minimum,RDD +setMethod("minimum", + signature(x = "RDD"), + function(x) { + reduce(x, min) + }) + +#' Add up the elements in an RDD. +#' +#' @param x The RDD to add up the elements in +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' sumRDD(rdd) # 55 +#'} +#' @rdname sumRDD +#' @aliases sumRDD,RDD +setMethod("sumRDD", + signature(x = "RDD"), + function(x) { + reduce(x, "+") + }) + +#' Applies a function to all elements in an RDD, and force evaluation. +#' +#' @param x The RDD to apply the function +#' @param func The function to be applied. +#' @return invisible NULL. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreach(rdd, function(x) { save(x, file=...) }) +#'} +#' @rdname foreach +#' @aliases foreach,RDD,function-method +setMethod("foreach", + signature(x = "RDD", func = "function"), + function(x, func) { + partition.func <- function(x) { + lapply(x, func) + NULL + } + invisible(collect(mapPartitions(x, partition.func))) + }) + +#' Applies a function to each partition in an RDD, and force evaluation. +#' +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#'} +#' @rdname foreach +#' @aliases foreachPartition,RDD,function-method +setMethod("foreachPartition", + signature(x = "RDD", func = "function"), + function(x, func) { + invisible(collect(mapPartitions(x, func))) + }) + +#' Take elements from an RDD. +#' +#' This function takes the first NUM elements in the RDD and +#' returns them in a list. +#' +#' @param x The RDD to take elements from +#' @param num Number of elements to take +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' take(rdd, 2L) # list(1, 2) +#'} +#' @rdname take +#' @aliases take,RDD,numeric-method +setMethod("take", + signature(x = "RDD", num = "numeric"), + function(x, num) { + resList <- list() + index <- -1 + jrdd <- getJRDD(x) + numPartitions <- numPartitions(x) + + # TODO(shivaram): Collect more than one partition based on size + # estimates similar to the scala version of `take`. + while (TRUE) { + index <- index + 1 + + if (length(resList) >= num || index >= numPartitions) + break + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + size <- num - length(resList) + # elems is capped to have at most `size` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = size, + serializedMode = getSerializedMode(x)) + # TODO: Check if this append is O(n^2)? + resList <- append(resList, elems) + } + resList + }) + +#' First +#' +#' Return the first element of an RDD +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' first(rdd) +#' } +setMethod("first", + signature(x = "RDD"), + function(x) { + take(x, 1)[[1]] + }) + +#' Removes the duplicates from RDD. +#' +#' This function returns a new RDD containing the distinct elements in the +#' given RDD. The same as `distinct()' in Spark. +#' +#' @param x The RDD to remove duplicates from. +#' @param numPartitions Number of partitions to create. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,2,3,3,3)) +#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#'} +#' @rdname distinct +#' @aliases distinct,RDD-method +setMethod("distinct", + signature(x = "RDD"), + function(x, numPartitions = SparkR::numPartitions(x)) { + identical.mapped <- lapply(x, function(x) { list(x, NULL) }) + reduced <- reduceByKey(identical.mapped, + function(x, y) { x }, + numPartitions) + resRDD <- lapply(reduced, function(x) { x[[1]] }) + resRDD + }) + +#' Return an RDD that is a sampled subset of the given RDD. +#' +#' The same as `sample()' in Spark. (We rename it due to signature +#' inconsistencies with the `sample()' function in R's base package.) +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) # ensure each num is in its own split +#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#'} +#' @rdname sampleRDD +#' @aliases sampleRDD,RDD +setMethod("sampleRDD", + signature(x = "RDD", withReplacement = "logical", + fraction = "numeric", seed = "integer"), + function(x, withReplacement, fraction, seed) { + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(split, part) { + set.seed(seed) + res <- vector("list", length(part)) + len <- 0 + + # Discards some random values to ensure each partition has a + # different random seed. + runif(split) + + for (elem in part) { + if (withReplacement) { + count <- rpois(1, fraction) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < fraction) { + len <- len + 1 + res[[len]] <- elem + } + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. + if (len > 0) + res[1:len] + else + list() + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) + +#' Return a list of the elements that are a sampled subset of the given RDD. +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param num Number of elements to return +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:100) +#' # exactly 5 elements sampled, which may not be distinct +#' takeSample(rdd, TRUE, 5L, 1618L) +#' # exactly 5 distinct elements sampled +#' takeSample(rdd, FALSE, 5L, 16181618L) +#'} +#' @rdname takeSample +#' @aliases takeSample,RDD +setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", + num = "integer", seed = "integer"), + function(x, withReplacement, num, seed) { + # This function is ported from RDD.scala. + fraction <- 0.0 + total <- 0 + multiplier <- 3.0 + initialCount <- count(x) + maxSelected <- 0 + MAXINT <- .Machine$integer.max + + if (num < 0) + stop(paste("Negative number of elements requested")) + + if (initialCount > MAXINT - 1) { + maxSelected <- MAXINT - 1 + } else { + maxSelected <- initialCount + } + + if (num > initialCount && !withReplacement) { + total <- maxSelected + fraction <- multiplier * (maxSelected + 1) / initialCount + } else { + total <- num + fraction <- multiplier * (num + 1) / initialCount + } + + set.seed(seed) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + # If the first sample didn't turn out large enough, keep trying to + # take samples; this shouldn't happen often because we use a big + # multiplier for thei initial size + while (length(samples) < total) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + + # TODO(zongheng): investigate if this call is an in-place shuffle? + sample(samples)[1:total] + }) + +#' Creates tuples of the elements in this RDD by applying a function. +#' +#' @param x The RDD. +#' @param func The function to be applied. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3)) +#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#'} +#' @rdname keyBy +#' @aliases keyBy,RDD +setMethod("keyBy", + signature(x = "RDD", func = "function"), + function(x, func) { + apply.func <- function(x) { + list(func(x), x) + } + lapply(x, apply.func) + }) + +#' Return a new RDD that has exactly numPartitions partitions. +#' Can increase or decrease the level of parallelism in this RDD. Internally, +#' this uses a shuffle to redistribute data. +#' If you are decreasing the number of partitions in this RDD, consider using +#' coalesce, which can avoid performing a shuffle. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso coalesce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +#' numPartitions(rdd) # 4 +#' numPartitions(repartition(rdd, 2L)) # 2 +#'} +#' @rdname repartition +#' @aliases repartition,RDD +setMethod("repartition", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions) { + coalesce(x, numToInt(numPartitions), TRUE) + }) + +#' Return a new RDD that is reduced into numPartitions partitions. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso repartition +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +#' numPartitions(rdd) # 3 +#' numPartitions(coalesce(rdd, 1L)) # 1 +#'} +#' @rdname coalesce +#' @aliases coalesce,RDD +setMethod("coalesce", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions, shuffle = FALSE) { + numPartitions <- numToInt(numPartitions) + if (shuffle || numPartitions > SparkR::numPartitions(x)) { + func <- function(s, part) { + set.seed(s) # split as seed + start <- as.integer(sample(numPartitions, 1) - 1) + lapply(seq_along(part), + function(i) { + pos <- (start + i) %% numPartitions + list(pos, part[[i]]) + }) + } + shuffled <- lapplyPartitionsWithIndex(x, func) + repartitioned <- partitionBy(shuffled, numPartitions) + values(repartitioned) + } else { + jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle) + RDD(jrdd) + } + }) + +#' Save this RDD as a SequenceFile of serialized objects. +#' +#' @param x The RDD to save +#' @param path The directory where the file is saved +#' @seealso objectFile +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsObjectFile +#' @aliases saveAsObjectFile,RDD +setMethod("saveAsObjectFile", + signature(x = "RDD", path = "character"), + function(x, path) { + # If serializedMode == "string" we need to serialize the data before saving it since + # objectFile() assumes serializedMode == "byte". + if (getSerializedMode(x) != "byte") { + x <- serializeToBytes(x) + } + # Return nothing + invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) + }) + +#' Save this RDD as a text file, using string representations of elements. +#' +#' @param x The RDD to save +#' @param path The directory where the splits of the text file are saved +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsTextFile +#' @aliases saveAsTextFile,RDD +setMethod("saveAsTextFile", + signature(x = "RDD", path = "character"), + function(x, path) { + func <- function(str) { + toString(str) + } + stringRdd <- lapply(x, func) + # Return nothing + invisible( + callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) + }) + +#' Sort an RDD by the given key function. +#' +#' @param x An RDD to be sorted. +#' @param func A function used to compute the sort key for each element. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(3, 2, 1)) +#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#'} +#' @rdname sortBy +#' @aliases sortBy,RDD,RDD-method +setMethod("sortBy", + signature(x = "RDD", func = "function"), + function(x, func, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + values(sortByKey(keyBy(x, func), ascending, numPartitions)) + }) + +# Helper function to get first N elements from an RDD in the specified order. +# Param: +# x An RDD. +# num Number of elements to return. +# ascending A flag to indicate whether the sorting is ascending or descending. +# Return: +# A list of the first N elements from the RDD in the specified order. +# +takeOrderedElem <- function(x, num, ascending = TRUE) { + if (num <= 0L) { + return(list()) + } + + partitionFunc <- function(part) { + if (num < length(part)) { + # R limitation: order works only on primitive types! + ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) + list(part[ord[1:num]]) + } else { + list(part) + } + } + + reduceFunc <- function(elems, part) { + newElems <- append(elems, part) + # R limitation: order works only on primitive types! + ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending) + newElems[ord[1:num]] + } + + newRdd <- mapPartitions(x, partitionFunc) + reduce(newRdd, reduceFunc) +} + +#' Returns the first N elements from an RDD in ascending order. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The first N elements from the RDD in ascending order. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#'} +#' @rdname takeOrdered +#' @aliases takeOrdered,RDD,RDD-method +setMethod("takeOrdered", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num) + }) + +#' Returns the top N elements from an RDD. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The top N elements from the RDD. +#' @rdname top +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#'} +#' @rdname top +#' @aliases top,RDD,RDD-method +setMethod("top", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num, FALSE) + }) + +#' Fold an RDD using a given associative function and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using a given associative function and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param op An associative function for the folding operation. +#' @return The folding result. +#' @rdname fold +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +#' fold(rdd, 0, "+") # 15 +#'} +#' @rdname fold +#' @aliases fold,RDD,RDD-method +setMethod("fold", + signature(x = "RDD", zeroValue = "ANY", op = "ANY"), + function(x, zeroValue, op) { + aggregateRDD(x, zeroValue, op, op) + }) + +#' Aggregate an RDD using the given combine functions and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using given combine functions and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the RDD elements. It may return a different +#' result type from the type of the RDD elements. +#' @param combOp A function to aggregate results of seqOp. +#' @return The aggregation result. +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4)) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#'} +#' @rdname aggregateRDD +#' @aliases aggregateRDD,RDD,RDD-method +setMethod("aggregateRDD", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), + function(x, zeroValue, seqOp, combOp) { + partitionFunc <- function(part) { + Reduce(seqOp, part, zeroValue) + } + + partitionList <- collect(lapplyPartition(x, partitionFunc), + flatten = FALSE) + Reduce(combOp, partitionList, zeroValue) + }) + +#' Pipes elements to a forked external process. +#' +#' The same as 'pipe()' in Spark. +#' +#' @param x The RDD whose elements are piped to the forked external process. +#' @param command The command to fork an external process. +#' @param env A named list to set environment variables of the external process. +#' @return A new RDD created by piping all elements to a forked external process. +#' @rdname pipeRDD +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(pipeRDD(rdd, "more") +#' Output: c("1", "2", ..., "10") +#'} +#' @rdname pipeRDD +#' @aliases pipeRDD,RDD,character-method +setMethod("pipeRDD", + signature(x = "RDD", command = "character"), + function(x, command, env = list()) { + func <- function(part) { + trim.trailing.func <- function(x) { + sub("[\r\n]*$", "", toString(x)) + } + input <- unlist(lapply(part, trim.trailing.func)) + res <- system2(command, stdout = TRUE, input = input, env = env) + lapply(res, trim.trailing.func) + } + lapplyPartition(x, func) + }) + +# TODO: Consider caching the name in the RDD's environment +#' Return an RDD's name. +#' +#' @param x The RDD whose name is returned. +#' @rdname name +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' name(rdd) # NULL (if not set before) +#'} +#' @rdname name +#' @aliases name,RDD +setMethod("name", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "name") + }) + +#' Set an RDD's name. +#' +#' @param x The RDD whose name is to be set. +#' @param name The RDD name to be set. +#' @return a new RDD renamed. +#' @rdname setName +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' setName(rdd, "myRDD") +#' name(rdd) # "myRDD" +#'} +#' @rdname setName +#' @aliases setName,RDD +setMethod("setName", + signature(x = "RDD", name = "character"), + function(x, name) { + callJMethod(getJRDD(x), "setName", name) + x + }) + +#' Zip an RDD with generated unique Long IDs. +#' +#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +#' n is the number of partitions. So there may exist gaps, but this +#' method won't trigger a spark job, which is different from +#' zipWithIndex. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithIndex +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithUniqueId(rdd)) +#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#'} +#' @rdname zipWithUniqueId +#' @aliases zipWithUniqueId,RDD +setMethod("zipWithUniqueId", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + + partitionFunc <- function(split, part) { + mapply( + function(item, index) { + list(item, (index - 1) * n + split) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Zip an RDD with its element indices. +#' +#' The ordering is first based on the partition index and then the +#' ordering of items within each partition. So the first item in +#' the first partition gets index 0, and the last item in the last +#' partition receives the largest index. +#' +#' This method needs to trigger a Spark job when this RDD contains +#' more than one partition. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithUniqueId +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithIndex(rdd)) +#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#'} +#' @rdname zipWithIndex +#' @aliases zipWithIndex,RDD +setMethod("zipWithIndex", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + if (n > 1) { + nums <- collect(lapplyPartition(x, + function(part) { + list(length(part)) + })) + startIndices <- Reduce("+", nums, accumulate = TRUE) + } + + partitionFunc <- function(split, part) { + if (split == 0) { + startIndex <- 0 + } else { + startIndex <- startIndices[[split]] + } + + mapply( + function(item, index) { + list(item, index - 1 + startIndex) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Coalesce all elements within each partition of an RDD into a list. +#' +#' @param x An RDD. +#' @return An RDD created by coalescing all elements within +#' each partition into a list. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, as.list(1:4), 2L) +#' collect(glom(rdd)) +#' # list(list(1, 2), list(3, 4)) +#'} +#' @rdname glom +#' @aliases glom,RDD +setMethod("glom", + signature(x = "RDD"), + function(x) { + partitionFunc <- function(part) { + list(part) + } + + lapplyPartition(x, partitionFunc) + }) + +############ Binary Functions ############# + +#' Return the union RDD of two RDDs. +#' The same as union() in Spark. +#' +#' @param x An RDD. +#' @param y An RDD. +#' @return a new RDD created by performing the simple union (witout removing +#' duplicates) of two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#'} +#' @rdname unionRDD +#' @aliases unionRDD,RDD,RDD-method +setMethod("unionRDD", + signature(x = "RDD", y = "RDD"), + function(x, y) { + if (getSerializedMode(x) == getSerializedMode(y)) { + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, getSerializedMode(x)) + } else { + # One of the RDDs is not serialized, we need to serialize it first. + if (getSerializedMode(x) != "byte") x <- serializeToBytes(x) + if (getSerializedMode(y) != "byte") y <- serializeToBytes(y) + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, "byte") + } + union.rdd + }) + +#' Zip an RDD with another RDD. +#' +#' Zips this RDD with another one, returning key-value pairs with the +#' first element in each RDD second element in each RDD, etc. Assumes +#' that the two RDDs have the same number of partitions and the same +#' number of elements in each partition (e.g. one was made through +#' a map on the other). +#' +#' @param x An RDD to be zipped. +#' @param other Another RDD to be zipped. +#' @return An RDD zipped from the two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 0:4) +#' rdd2 <- parallelize(sc, 1000:1004) +#' collect(zipRDD(rdd1, rdd2)) +#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#'} +#' @rdname zipRDD +#' @aliases zipRDD,RDD +setMethod("zipRDD", + signature(x = "RDD", other = "RDD"), + function(x, other) { + n1 <- numPartitions(x) + n2 <- numPartitions(other) + if (n1 != n2) { + stop("Can only zip RDDs which have the same number of partitions.") + } + + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # check if corresponding partitions of both RDDs have the same number of elements. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. + appendLength <- function(part) { + part[[length(part) + 1]] <- length(part) + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + + zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other)) + # The zippedRDD's elements are of scala Tuple2 type. The serialized + # flag Here is used for the elements inside the tuples. + serializerMode <- getSerializedMode(x) + zippedRDD <- RDD(zippedJRDD, serializerMode) + + partitionFunc <- function(split, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # check if corresponding partitions of both RDDs have the same number of elements. + if (lengthOfKeys != lengthOfValues) { + stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + values <- part[(lengthOfKeys + 1) : (len - 1)] + } else { + keys <- list() + values <- list() + } + } else { + # Keys, values must have same length here, because this has + # been validated inside the JavaRDD.zip() function. + keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { + list(k, v) + }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(zippedRDD, partitionFunc) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R new file mode 100644 index 0000000000000..930ada22f4c38 --- /dev/null +++ b/R/pkg/R/SQLContext.R @@ -0,0 +1,520 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# SQLcontext.R: SQLContext-driven functions + +#' infer the SQL type +infer_type <- function(x) { + if (is.null(x)) { + stop("can not infer type from NULL") + } + + # class of POSIXlt is c("POSIXlt" "POSIXt") + type <- switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) + + if (type == "map") { + stopifnot(length(x) > 0) + key <- ls(x)[[1]] + list(type = "map", + keyType = "string", + valueType = infer_type(get(key, x)), + valueContainsNull = TRUE) + } else if (type == "array") { + stopifnot(length(x) > 0) + names <- names(x) + if (is.null(names)) { + list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + } else { + # StructType + types <- lapply(x, infer_type) + fields <- lapply(1:length(x), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + list(type = "struct", fields = fields) + } + } else if (length(x) > 1) { + list(type = "array", elementType = type, containsNull = TRUE) + } else { + type + } +} + +#' dump the schema into JSON string +tojson <- function(x) { + if (is.list(x)) { + names <- names(x) + if (!is.null(names)) { + items <- lapply(names, function(n) { + safe_n <- gsub('"', '\\"', n) + paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') + }) + d <- paste(items, collapse = ', ') + paste('{', d, '}', sep = '') + } else { + l <- paste(lapply(x, tojson), collapse = ', ') + paste('[', l, ']', sep = '') + } + } else if (is.character(x)) { + paste('"', x, '"', sep = '') + } else if (is.logical(x)) { + if (x) "true" else "false" + } else { + stop(paste("unexpected type:", class(x))) + } +} + +#' Create a DataFrame from an RDD +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param sqlCtx A SQLContext +#' @param data An RDD or list or data.frame +#' @param schema a list of column names or named list (StructType), optional +#' @return an DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- createDataFrame(sqlCtx, rdd) +#' } + +# TODO(davies): support sampling and infer type from NA +createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { + if (is.data.frame(data)) { + # get the names of columns, they will be put into RDD + schema <- names(data) + n <- nrow(data) + m <- ncol(data) + # get rid of factor type + dropFactor <- function(x) { + if (is.factor(x)) { + as.character(x) + } else { + x + } + } + data <- lapply(1:n, function(i) { + lapply(1:m, function(j) { dropFactor(data[i,j]) }) + }) + } + if (is.list(data)) { + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx) + rdd <- parallelize(sc, data) + } else if (inherits(data, "RDD")) { + rdd <- data + } else { + stop(paste("unexpected type:", class(data))) + } + + if (is.null(schema) || is.null(names(schema))) { + row <- first(rdd) + names <- if (is.null(schema)) { + names(row) + } else { + as.list(schema) + } + if (is.null(names)) { + names <- lapply(1:length(row), function(x) { + paste("_", as.character(x), sep = "") + }) + } + + # SPAKR-SQL does not support '.' in column name, so replace it with '_' + # TODO(davies): remove this once SPARK-2775 is fixed + names <- lapply(names, function(n) { + nn <- gsub("[.]", "_", n) + if (nn != n) { + warning(paste("Use", nn, "instead of", n, " as column name")) + } + nn + }) + + types <- lapply(row, infer_type) + fields <- lapply(1:length(row), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + schema <- list(type = "struct", fields = fields) + } + + stopifnot(class(schema) == "list") + stopifnot(schema$type == "struct") + stopifnot(class(schema$fields) == "list") + schemaString <- tojson(schema) + + jrdd <- getJRDD(lapply(rdd, function(x) x), "row") + srdd <- callJMethod(jrdd, "rdd") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", + srdd, schemaString, sqlCtx) + dataFrame(sdf) +} + +#' toDF +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param x An RDD +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- toDF(rdd) +#' } + +setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) + +setMethod("toDF", signature(x = "RDD"), + function(x, ...) { + sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + get(".sparkRHivesc", envir = .sparkREnv) + } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + get(".sparkRSQLsc", envir = .sparkREnv) + } else { + stop("no SQL context available") + } + createDataFrame(sqlCtx, x, ...) + }) + +#' Create a DataFrame from a JSON file. +#' +#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' It goes through the entire dataset once to determine the schema. +#' +#' @param sqlCtx SQLContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' } + +jsonFile <- function(sqlCtx, path) { + # Allow the user to have a more flexible definiton of the text file path + path <- normalizePath(path) + # Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + sdf <- callJMethod(sqlCtx, "jsonFile", path) + dataFrame(sdf) +} + + +#' JSON RDD +#' +#' Loads an RDD storing one JSON object per string as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param rdd An RDD of JSON string +#' @param schema A StructType object to use as schema +#' @param samplingRatio The ratio of simpling used to infer the schema +#' @return A DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- texFile(sc, "path/to/json") +#' df <- jsonRDD(sqlCtx, rdd) +#' } + +# TODO: support schema +jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { + rdd <- serializeToString(rdd) + if (is.null(schema)) { + sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + dataFrame(sdf) + } else { + stop("not implemented") + } +} + + +#' Create a DataFrame from a Parquet file. +#' +#' Loads a Parquet file, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param ... Path(s) of parquet file(s) to read. +#' @return DataFrame +#' @export + +# TODO: Implement saveasParquetFile and write examples for both +parquetFile <- function(sqlCtx, ...) { + # Allow the user to have a more flexible definiton of the text file path + paths <- lapply(list(...), normalizePath) + sdf <- callJMethod(sqlCtx, "parquetFile", paths) + dataFrame(sdf) +} + +#' SQL Query +#' +#' Executes a SQL query using Spark, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param sqlQuery A character vector containing the SQL query +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- sql(sqlCtx, "SELECT * FROM table") +#' } + +sql <- function(sqlCtx, sqlQuery) { + sdf <- callJMethod(sqlCtx, "sql", sqlQuery) + dataFrame(sdf) +} + + +#' Create a DataFrame from a SparkSQL Table +#' +#' Returns the specified Table as a DataFrame. The Table must have already been registered +#' in the SQLContext. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The SparkSQL Table to convert to a DataFrame. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- table(sqlCtx, "table") +#' } + +table <- function(sqlCtx, tableName) { + sdf <- callJMethod(sqlCtx, "table", tableName) + dataFrame(sdf) +} + + +#' Tables +#' +#' Returns a DataFrame containing names of tables in the given database. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tables(sqlCtx, "hive") +#' } + +tables <- function(sqlCtx, databaseName = NULL) { + jdf <- if (is.null(databaseName)) { + callJMethod(sqlCtx, "tables") + } else { + callJMethod(sqlCtx, "tables", databaseName) + } + dataFrame(jdf) +} + + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a list of table names +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tableNames(sqlCtx, "hive") +#' } + +tableNames <- function(sqlCtx, databaseName = NULL) { + if (is.null(databaseName)) { + callJMethod(sqlCtx, "tableNames") + } else { + callJMethod(sqlCtx, "tableNames", databaseName) + } +} + + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being cached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' cacheTable(sqlCtx, "table") +#' } + +cacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "cacheTable", tableName) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being uncached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' uncacheTable(sqlCtx, "table") +#' } + +uncacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "uncacheTable", tableName) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @examples +#' \dontrun{ +#' clearCache(sqlCtx) +#' } + +clearCache <- function(sqlCtx) { + callJMethod(sqlCtx, "clearCache") +} + +#' Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the SparkSQL table to be dropped. +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' registerTempTable(df, "table") +#' dropTempTable(sqlCtx, "table") +#' } + +dropTempTable <- function(sqlCtx, tableName) { + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + callJMethod(sqlCtx, "dropTempTable", tableName) +} + +#' Load an DataFrame +#' +#' Returns the dataset in a data source as a DataFrame +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- load(sqlCtx, "path/to/file.json", source = "json") +#' } + +loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "load", source, options) + dataFrame(sdf) +} + +#' Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns the DataFrame associated with the external table. +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName A name of the table +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' } + +createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + dataFrame(sdf) +} diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R new file mode 100644 index 0000000000000..962fba5b3cf03 --- /dev/null +++ b/R/pkg/R/SQLTypes.R @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions for handling SparkSQL DataTypes. + +# Handler for StructType +structType <- function(st) { + obj <- structure(new.env(parent = emptyenv()), class = "structType") + obj$jobj <- st + obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) } + obj +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) { + fieldsList <- lapply(x$fields(), function(i) { i$print() }) + print(fieldsList) +} + +# Handler for StructField +structField <- function(sf) { + obj <- structure(new.env(parent = emptyenv()), class = "structField") + obj$jobj <- sf + obj$name <- function() { callJMethod(sf, "name") } + obj$dataType <- function() { callJMethod(sf, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(sf, "nullable") } + obj$print <- function() { paste("StructField(", + paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "), + ")", sep = "") } + obj +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat(x$print()) +} diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R new file mode 100644 index 0000000000000..2fb6fae55f28c --- /dev/null +++ b/R/pkg/R/backend.R @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Methods to call into SparkRBackend. + + +# Returns TRUE if object is an instance of given class +isInstanceOf <- function(jobj, className) { + stopifnot(class(jobj) == "jobj") + cls <- callJStatic("java.lang.Class", "forName", className) + callJMethod(cls, "isInstance", jobj) +} + +# Call a Java method named methodName on the object +# specified by objId. objId should be a "jobj" returned +# from the SparkRBackend. +callJMethod <- function(objId, methodName, ...) { + stopifnot(class(objId) == "jobj") + if (!isValidJobj(objId)) { + stop("Invalid jobj ", objId$id, + ". If SparkR was restarted, Spark operations need to be re-executed.") + } + invokeJava(isStatic = FALSE, objId$id, methodName, ...) +} + +# Call a static method on a specified className +callJStatic <- function(className, methodName, ...) { + invokeJava(isStatic = TRUE, className, methodName, ...) +} + +# Create a new object of the specified class name +newJObject <- function(className, ...) { + invokeJava(isStatic = TRUE, className, methodName = "", ...) +} + +# Remove an object from the SparkR backend. This is done +# automatically when a jobj is garbage collected. +removeJObject <- function(objId) { + invokeJava(isStatic = TRUE, "SparkRHandler", "rm", objId) +} + +isRemoveMethod <- function(isStatic, objId, methodName) { + isStatic == TRUE && objId == "SparkRHandler" && methodName == "rm" +} + +# Invoke a Java method on the SparkR backend. Users +# should typically use one of the higher level methods like +# callJMethod, callJStatic etc. instead of using this. +# +# isStatic - TRUE if the method to be called is static +# objId - String that refers to the object on which method is invoked +# Should be a jobj id for non-static methods and the classname +# for static methods +# methodName - name of method to be invoked +invokeJava <- function(isStatic, objId, methodName, ...) { + if (!exists(".sparkRCon", .sparkREnv)) { + stop("No connection to backend found. Please re-run sparkR.init") + } + + # If this isn't a removeJObject call + if (!isRemoveMethod(isStatic, objId, methodName)) { + objsToRemove <- ls(.toRemoveJobjs) + if (length(objsToRemove) > 0) { + sapply(objsToRemove, + function(e) { + removeJObject(e) + }) + rm(list = objsToRemove, envir = .toRemoveJobjs) + } + } + + + rc <- rawConnection(raw(0), "r+") + + writeBoolean(rc, isStatic) + writeString(rc, objId) + writeString(rc, methodName) + + args <- list(...) + writeInt(rc, length(args)) + writeArgs(rc, args) + + # Construct the whole request message to send it once, + # avoiding write-write-read pattern in case of Nagle's algorithm. + # Refer to http://en.wikipedia.org/wiki/Nagle%27s_algorithm for the details. + bytesToSend <- rawConnectionValue(rc) + close(rc) + rc <- rawConnection(raw(0), "r+") + writeInt(rc, length(bytesToSend)) + writeBin(bytesToSend, rc) + requestMessage <- rawConnectionValue(rc) + close(rc) + + conn <- get(".sparkRCon", .sparkREnv) + writeBin(requestMessage, conn) + + # TODO: check the status code to output error information + returnStatus <- readInt(conn) + stopifnot(returnStatus == 0) + readObject(conn) +} diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R new file mode 100644 index 0000000000000..583fa2e7fdcfd --- /dev/null +++ b/R/pkg/R/broadcast.R @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# S4 class representing Broadcast variables + +# Hidden environment that holds values for broadcast variables +# This will not be serialized / shipped by default +.broadcastNames <- new.env() +.broadcastValues <- new.env() +.broadcastIdToName <- new.env() + +#' @title S4 class that represents a Broadcast variable +#' @description Broadcast variables can be created using the broadcast +#' function from a \code{SparkContext}. +#' @rdname broadcast-class +#' @seealso broadcast +#' +#' @param id Id of the backing Spark broadcast variable +#' @export +setClass("Broadcast", slots = list(id = "character")) + +#' @rdname broadcast-class +#' @param value Value of the broadcast variable +#' @param jBroadcastRef reference to the backing Java broadcast object +#' @param objName name of broadcasted object +#' @export +Broadcast <- function(id, value, jBroadcastRef, objName) { + .broadcastValues[[id]] <- value + .broadcastNames[[as.character(objName)]] <- jBroadcastRef + .broadcastIdToName[[id]] <- as.character(objName) + new("Broadcast", id = id) +} + +#' @description +#' \code{value} can be used to get the value of a broadcast variable inside +#' a distributed function. +#' +#' @param bcast The broadcast variable to get +#' @rdname broadcast +#' @aliases value,Broadcast-method +setMethod("value", + signature(bcast = "Broadcast"), + function(bcast) { + if (exists(bcast@id, envir = .broadcastValues)) { + get(bcast@id, envir = .broadcastValues) + } else { + NULL + } + }) + +#' Internal function to set values of a broadcast variable. +#' +#' This function is used internally by Spark to set the value of a broadcast +#' variable on workers. Not intended for use outside the package. +#' +#' @rdname broadcast-internal +#' @seealso broadcast, value + +#' @param bcastId The id of broadcast variable to set +#' @param value The value to be set +#' @export +setBroadcastValue <- function(bcastId, value) { + bcastIdStr <- as.character(bcastId) + .broadcastValues[[bcastIdStr]] <- value +} + +#' Helper function to clear the list of broadcast variables we know about +#' Should be called when the SparkR JVM backend is shutdown +clearBroadcastVariables <- function() { + bcasts <- ls(.broadcastNames) + rm(list = bcasts, envir = .broadcastNames) +} diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R new file mode 100644 index 0000000000000..1281c41213e32 --- /dev/null +++ b/R/pkg/R/client.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Client code to connect to SparkRBackend + +# Creates a SparkR client connection object +# if one doesn't already exist +connectBackend <- function(hostname, port, timeout = 6000) { + if (exists(".sparkRcon", envir = .sparkREnv)) { + if (isOpen(.sparkREnv[[".sparkRCon"]])) { + cat("SparkRBackend client connection already exists\n") + return(get(".sparkRcon", envir = .sparkREnv)) + } + } + + con <- socketConnection(host = hostname, port = port, server = FALSE, + blocking = TRUE, open = "wb", timeout = timeout) + + assign(".sparkRCon", con, envir = .sparkREnv) + con +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { + if (.Platform$OS.type == "unix") { + sparkSubmitBinName = "spark-submit" + } else { + sparkSubmitBinName = "spark-submit.cmd" + } + + if (sparkHome != "") { + sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) + } else { + sparkSubmitBin <- sparkSubmitBinName + } + + if (jars != "") { + jars <- paste("--jars", jars) + } + + combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") + invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) +} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R new file mode 100644 index 0000000000000..e196305186b9a --- /dev/null +++ b/R/pkg/R/column.R @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Column Class + +#' @include generics.R jobj.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame column +#' @description The column class supports unary, binary operations on DataFrame columns + +#' @rdname column +#' +#' @param jc reference to JVM DataFrame column +#' @export +setClass("Column", + slots = list(jc = "jobj")) + +setMethod("initialize", "Column", function(.Object, jc) { + .Object@jc <- jc + .Object +}) + +column <- function(jc) { + new("Column", jc) +} + +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' @rdname show +setMethod("show", "Column", + function(object) { + cat("Column", callJMethod(object@jc, "toString"), "\n") + }) + +operators <- list( + "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", + "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", + # we can not override `&&` and `||`, so use `&` and `|` instead + "&" = "and", "|" = "or" #, "!" = "unary_$bang" +) +column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") +functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", + "first", "last", "lower", "upper", "sumDistinct") + +createOperator <- function(op) { + setMethod(op, + signature(e1 = "Column"), + function(e1, e2) { + jc <- if (missing(e2)) { + if (op == "-") { + callJMethod(e1@jc, "unary_$minus") + } else { + callJMethod(e1@jc, operators[[op]]) + } + } else { + if (class(e2) == "Column") { + e2 <- e2@jc + } + callJMethod(e1@jc, operators[[op]], e2) + } + column(jc) + }) +} + +createColumnFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + column(callJMethod(x@jc, name)) + }) +} + +createColumnFunction2 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x, data) { + if (class(data) == "Column") { + data <- data@jc + } + jc <- callJMethod(x@jc, name, data) + column(jc) + }) +} + +createStaticFunction <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createMethods <- function() { + for (op in names(operators)) { + createOperator(op) + } + for (name in column_functions1) { + createColumnFunction1(name) + } + for (name in column_functions2) { + createColumnFunction2(name) + } + for (x in functions) { + createStaticFunction(x) + } +} + +createMethods() + +#' alias +#' +#' Set a new name for a column +setMethod("alias", + signature(object = "Column"), + function(object, data) { + if (is.character(data)) { + column(callJMethod(object@jc, "as", data)) + } else { + stop("data should be character") + } + }) + +#' An expression that returns a substring. +#' +#' @param start starting position +#' @param stop ending position +setMethod("substr", signature(x = "Column"), + function(x, start, stop) { + jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + column(jc) + }) + +#' Casts the column to a different data type. +#' @examples +#' \dontrun{ +#' cast(df$age, "string") +#' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) +#' } +setMethod("cast", + signature(x = "Column"), + function(x, dataType) { + if (is.character(dataType)) { + column(callJMethod(x@jc, "cast", dataType)) + } else if (is.list(dataType)) { + json <- tojson(dataType) + jdataType <- callJStatic("org.apache.spark.sql.types.DataType", "fromJson", json) + column(callJMethod(x@jc, "cast", jdataType)) + } else { + stop("dataType should be character or list") + } + }) + +#' Approx Count Distinct +#' +#' Returns the approximate number of distinct items in a group. +#' +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' returns the number of distinct items in a group. +#' +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R new file mode 100644 index 0000000000000..2fc0bb294bcce --- /dev/null +++ b/R/pkg/R/context.R @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# context.R: SparkContext driven functions + +getMinSplits <- function(sc, minSplits) { + if (is.null(minSplits)) { + defaultParallelism <- callJMethod(sc, "defaultParallelism") + minSplits <- min(defaultParallelism, 2) + } + as.integer(minSplits) +} + +#' Create an RDD from a text file. +#' +#' This function reads a text file from HDFS, a local file system (available on all +#' nodes), or any Hadoop-supported file system URI, and creates an +#' RDD of strings from it. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD where each item is of type \code{character} +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' lines <- textFile(sc, "myfile.txt") +#'} +textFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "textFile", path, getMinSplits(sc, minSplits)) + # jrdd is of type JavaRDD[String] + RDD(jrdd, "string") +} + +#' Load an RDD saved as a SequenceFile containing serialized objects. +#' +#' The file to be loaded should be one that was previously generated by calling +#' saveAsObjectFile() of the RDD class. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD containing serialized R objects. +#' @seealso saveAsObjectFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- objectFile(sc, "myfile") +#'} +objectFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "objectFile", path, getMinSplits(sc, minSplits)) + # Assume the RDD contains serialized R objects. + RDD(jrdd, "byte") +} + +#' Create an RDD from a homogeneous list or vector. +#' +#' This function creates an RDD from a local homogeneous list in R. The elements +#' in the list are split into \code{numSlices} slices and distributed to nodes +#' in the cluster. +#' +#' @param sc SparkContext to use +#' @param coll collection to parallelize +#' @param numSlices number of partitions to create in the RDD +#' @return an RDD created from this collection +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2) +#' # The RDD should contain 10 elements +#' length(rdd) +#'} +parallelize <- function(sc, coll, numSlices = 1) { + # TODO: bound/safeguard numSlices + # TODO: unit tests for if the split works for all primitives + # TODO: support matrix, data frame, etc + if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { + if (is.data.frame(coll)) { + message(paste("context.R: A data frame is parallelized by columns.")) + } else { + if (is.matrix(coll)) { + message(paste("context.R: A matrix is parallelized by elements.")) + } else { + message(paste("context.R: parallelize() currently only supports lists and vectors.", + "Calling as.list() to coerce coll into a list.")) + } + } + coll <- as.list(coll) + } + + if (numSlices > length(coll)) + numSlices <- length(coll) + + sliceLen <- ceiling(length(coll) / numSlices) + slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + + # Serialize each slice: obtain a list of raws, or a list of lists (slices) of + # 2-tuples of raws + serializedSlices <- lapply(slices, serialize, connection = NULL) + + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", + "createRDDFromArray", sc, serializedSlices) + + RDD(jrdd, "byte") +} + +#' Include this specified package on all workers +#' +#' This function can be used to include a package on all workers before the +#' user's code is executed. This is useful in scenarios where other R package +#' functions are used in a function passed to functions like \code{lapply}. +#' NOTE: The package is assumed to be installed on every node in the Spark +#' cluster. +#' +#' @param sc SparkContext to use +#' @param pkg Package name +#' +#' @export +#' @examples +#'\dontrun{ +#' library(Matrix) +#' +#' sc <- sparkR.init() +#' # Include the matrix library we will be using +#' includePackage(sc, Matrix) +#' +#' generateSparse <- function(x) { +#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +#' } +#' +#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +#' collect(rdd) +#'} +includePackage <- function(sc, pkg) { + pkg <- as.character(substitute(pkg)) + if (exists(".packages", .sparkREnv)) { + packages <- .sparkREnv$.packages + } else { + packages <- list() + } + packages <- c(packages, pkg) + .sparkREnv$.packages <- packages +} + +#' @title Broadcast a variable to all workers +#' +#' @description +#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +#' object for reading it in distributed functions. +#' +#' @param sc Spark Context to use +#' @param object Object to be broadcast +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2, 2L) +#' +#' # Large Matrix object that we want to broadcast +#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +#' randomMatBr <- broadcast(sc, randomMat) +#' +#' # Use the broadcast variable inside the function +#' useBroadcast <- function(x) { +#' sum(value(randomMatBr) * x) +#' } +#' sumRDD <- lapply(rdd, useBroadcast) +#'} +broadcast <- function(sc, object) { + objName <- as.character(substitute(object)) + serializedObj <- serialize(object, connection = NULL) + + jBroadcast <- callJMethod(sc, "broadcast", serializedObj) + id <- as.character(callJMethod(jBroadcast, "id")) + + Broadcast(id, object, jBroadcast, objName) +} + +#' @title Set the checkpoint directory +#' +#' Set the directory under which RDDs are going to be checkpointed. The +#' directory must be a HDFS path if running on a cluster. +#' +#' @param sc Spark Context to use +#' @param dirName Directory path +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "~/checkpoints") +#' rdd <- parallelize(sc, 1:2, 2L) +#' checkpoint(rdd) +#'} +setCheckpointDir <- function(sc, dirName) { + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) +} diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R new file mode 100644 index 0000000000000..257b435607ce8 --- /dev/null +++ b/R/pkg/R/deserialize.R @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to deserialize objects from Java. + +# Type mapping from Java to R +# +# void -> NULL +# Int -> integer +# String -> character +# Boolean -> logical +# Double -> double +# Long -> double +# Array[Byte] -> raw +# Date -> Date +# Time -> POSIXct +# +# Array[T] -> list() +# Object -> jobj + +readObject <- function(con) { + # Read type first + type <- readType(con) + readTypedObject(con, type) +} + +readTypedObject <- function(con, type) { + switch (type, + "i" = readInt(con), + "c" = readString(con), + "b" = readBoolean(con), + "d" = readDouble(con), + "r" = readRaw(con), + "D" = readDate(con), + "t" = readTime(con), + "l" = readList(con), + "n" = NULL, + "j" = getJobj(readString(con)), + stop(paste("Unsupported type for deserialization", type))) +} + +readString <- function(con) { + stringLen <- readInt(con) + string <- readBin(con, raw(), stringLen, endian = "big") + rawToChar(string) +} + +readInt <- function(con) { + readBin(con, integer(), n = 1, endian = "big") +} + +readDouble <- function(con) { + readBin(con, double(), n = 1, endian = "big") +} + +readBoolean <- function(con) { + as.logical(readInt(con)) +} + +readType <- function(con) { + rawToChar(readBin(con, "raw", n = 1L)) +} + +readDate <- function(con) { + as.Date(readString(con)) +} + +readTime <- function(con) { + t <- readDouble(con) + as.POSIXct(t, origin = "1970-01-01") +} + +# We only support lists where all elements are of same type +readList <- function(con) { + type <- readType(con) + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + l[[i]] <- readTypedObject(con, type) + } + l + } else { + list() + } +} + +readRaw <- function(con) { + dataLen <- readInt(con) + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readRawLen <- function(con, dataLen) { + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readDeserialize <- function(con) { + # We have two cases that are possible - In one, the entire partition is + # encoded as a byte array, so we have only one value to read. If so just + # return firstData + dataLen <- readInt(con) + firstData <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + + # Else, read things into a list + dataLen <- readInt(con) + if (length(dataLen) > 0 && dataLen > 0) { + data <- list(firstData) + while (length(dataLen) > 0 && dataLen > 0) { + data[[length(data) + 1L]] <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + dataLen <- readInt(con) + } + unlist(data, recursive = FALSE) + } else { + firstData + } +} + +readDeserializeRows <- function(inputCon) { + # readDeserializeRows will deserialize a DataOutputStream composed of + # a list of lists. Since the DOS is one continuous stream and + # the number of rows varies, we put the readRow function in a while loop + # that termintates when the next row is empty. + data <- list() + while(TRUE) { + row <- readRow(inputCon) + if (length(row) == 0) { + break + } + data[[length(data) + 1L]] <- row + } + data # this is a list of named lists now +} + +readRowList <- function(obj) { + # readRowList is meant for use inside an lapply. As a result, it is + # necessary to open a standalone connection for the row and consume + # the numCols bytes inside the read function in order to correctly + # deserialize the row. + rawObj <- rawConnection(obj, "r+") + on.exit(close(rawObj)) + readRow(rawObj) +} + +readRow <- function(inputCon) { + numCols <- readInt(inputCon) + if (length(numCols) > 0 && numCols > 0) { + lapply(1:numCols, function(x) { + obj <- readObject(inputCon) + if (is.null(obj)) { + NA + } else { + obj + } + }) # each row is a list now + } else { + list() + } +} + +# Take a single column as Array[Byte] and deserialize it into an atomic vector +readCol <- function(inputCon, numRows) { + # sapply can not work with POSIXlt + do.call(c, lapply(1:numRows, function(x) { + value <- readObject(inputCon) + # Replace NULL with NA so we can coerce to vectors + if (is.null(value)) NA else value + })) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R new file mode 100644 index 0000000000000..5fb1ccaa84ee2 --- /dev/null +++ b/R/pkg/R/generics.R @@ -0,0 +1,543 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############ RDD Actions and Transformations ############ + +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) + +#' @rdname cache-methods +#' @export +setGeneric("cache", function(x) { standardGeneric("cache") }) + +#' @rdname coalesce +#' @seealso repartition +#' @export +setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") }) + +#' @rdname checkpoint-methods +#' @export +setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) + +#' @rdname collect-methods +#' @export +setGeneric("collect", function(x, ...) { standardGeneric("collect") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectPartition", + function(x, partitionId) { + standardGeneric("collectPartition") + }) + +#' @rdname count +#' @export +setGeneric("count", function(x) { standardGeneric("count") }) + +#' @rdname countByValue +#' @export +setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) + +#' @rdname distinct +#' @export +setGeneric("distinct", function(x, numPartitions = 1L) { standardGeneric("distinct") }) + +#' @rdname filterRDD +#' @export +setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) + +#' @rdname first +#' @export +setGeneric("first", function(x) { standardGeneric("first") }) + +#' @rdname flatMap +#' @export +setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) + +#' @rdname fold +#' @seealso reduce +#' @export +setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) + +#' @rdname foreach +#' @export +setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) + +#' @rdname foreach +#' @export +setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) + +# The jrdd accessor function. +setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) + +#' @rdname glom +#' @export +setGeneric("glom", function(x) { standardGeneric("glom") }) + +#' @rdname keyBy +#' @export +setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("lapplyPartitionsWithIndex", + function(X, FUN) { + standardGeneric("lapplyPartitionsWithIndex") + }) + +#' @rdname lapply +#' @export +setGeneric("map", function(X, FUN) { standardGeneric("map") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("mapPartitionsWithIndex", + function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) + +#' @rdname maximum +#' @export +setGeneric("maximum", function(x) { standardGeneric("maximum") }) + +#' @rdname minimum +#' @export +setGeneric("minimum", function(x) { standardGeneric("minimum") }) + +#' @rdname sumRDD +#' @export +setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) + +#' @rdname name +#' @export +setGeneric("name", function(x) { standardGeneric("name") }) + +#' @rdname numPartitions +#' @export +setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) + +#' @rdname persist +#' @export +setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) + +#' @rdname pipeRDD +#' @export +setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) + +#' @rdname reduce +#' @export +setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) + +#' @rdname repartition +#' @seealso coalesce +#' @export +setGeneric("repartition", function(x, numPartitions) { standardGeneric("repartition") }) + +#' @rdname sampleRDD +#' @export +setGeneric("sampleRDD", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleRDD") + }) + +#' @rdname saveAsObjectFile +#' @seealso objectFile +#' @export +setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) + +#' @rdname saveAsTextFile +#' @export +setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) + +#' @rdname setName +#' @export +setGeneric("setName", function(x, name) { standardGeneric("setName") }) + +#' @rdname sortBy +#' @export +setGeneric("sortBy", + function(x, func, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortBy") + }) + +#' @rdname take +#' @export +setGeneric("take", function(x, num) { standardGeneric("take") }) + +#' @rdname takeOrdered +#' @export +setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) + +#' @rdname takeSample +#' @export +setGeneric("takeSample", + function(x, withReplacement, num, seed) { + standardGeneric("takeSample") + }) + +#' @rdname top +#' @export +setGeneric("top", function(x, num) { standardGeneric("top") }) + +#' @rdname unionRDD +#' @export +setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) + +#' @rdname unpersist-methods +#' @export +setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) + +#' @rdname zipRDD +#' @export +setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) + +#' @rdname zipWithIndex +#' @seealso zipWithUniqueId +#' @export +setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) + +#' @rdname zipWithUniqueId +#' @seealso zipWithIndex +#' @export +setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) + + +############ Binary Functions ############# + +#' @rdname countByKey +#' @export +setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) + +#' @rdname flatMapValues +#' @export +setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) + +#' @rdname keys +#' @export +setGeneric("keys", function(x) { standardGeneric("keys") }) + +#' @rdname lookup +#' @export +setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) + +#' @rdname mapValues +#' @export +setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) + +#' @rdname values +#' @export +setGeneric("values", function(x) { standardGeneric("values") }) + + + +############ Shuffle Functions ############ + +#' @rdname aggregateByKey +#' @seealso foldByKey, combineByKey +#' @export +setGeneric("aggregateByKey", + function(x, zeroValue, seqOp, combOp, numPartitions) { + standardGeneric("aggregateByKey") + }) + +#' @rdname cogroup +#' @export +setGeneric("cogroup", + function(..., numPartitions) { + standardGeneric("cogroup") + }, + signature = "...") + +#' @rdname combineByKey +#' @seealso groupByKey, reduceByKey +#' @export +setGeneric("combineByKey", + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + standardGeneric("combineByKey") + }) + +#' @rdname foldByKey +#' @seealso aggregateByKey, combineByKey +#' @export +setGeneric("foldByKey", + function(x, zeroValue, func, numPartitions) { + standardGeneric("foldByKey") + }) + +#' @rdname join-methods +#' @export +setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) + +#' @rdname groupByKey +#' @seealso reduceByKey +#' @export +setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) + +#' @rdname join-methods +#' @export +setGeneric("join", function(x, y, ...) { standardGeneric("join") }) + +#' @rdname join-methods +#' @export +setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) + +#' @rdname partitionBy +#' @export +setGeneric("partitionBy", function(x, numPartitions, ...) { standardGeneric("partitionBy") }) + +#' @rdname reduceByKey +#' @seealso groupByKey +#' @export +setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) + +#' @rdname reduceByKeyLocally +#' @seealso reduceByKey +#' @export +setGeneric("reduceByKeyLocally", + function(x, combineFunc) { + standardGeneric("reduceByKeyLocally") + }) + +#' @rdname join-methods +#' @export +setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) + +#' @rdname sortByKey +#' @export +setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortByKey") +}) + + +################### Broadcast Variable Methods ################# + +#' @rdname broadcast +#' @export +setGeneric("value", function(bcast) { standardGeneric("value") }) + + + +#################### DataFrame Methods ######################## + +#' @rdname schema +#' @export +setGeneric("columns", function(x) {standardGeneric("columns") }) + +#' @rdname schema +#' @export +setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) + +#' @rdname explain +#' @export +setGeneric("explain", function(x, ...) { standardGeneric("explain") }) + +#' @rdname filter +#' @export +setGeneric("filter", function(x, condition) { standardGeneric("filter") }) + +#' @rdname DataFrame +#' @export +setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) + +#' @rdname insertInto +#' @export +setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) + +#' @rdname intersect +#' @export +setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) + +#' @rdname isLocal +#' @export +setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) + +#' @rdname limit +#' @export +setGeneric("limit", function(x, num) {standardGeneric("limit") }) + +#' @rdname sortDF +#' @export +setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) + +#' @rdname schema +#' @export +setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) + +#' @rdname registerTempTable +#' @export +setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) + +#' @rdname sampleDF +#' @export +setGeneric("sampleDF", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleDF") + }) + +#' @rdname saveAsParquetFile +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + +#' @rdname saveAsTable +#' @export +setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { + standardGeneric("saveAsTable") +}) + +#' @rdname saveAsTable +#' @export +setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) + +#' @rdname schema +#' @export +setGeneric("schema", function(x) { standardGeneric("schema") }) + +#' @rdname select +#' @export +setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) + +#' @rdname select +#' @export +setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) + +#' @rdname showDF +#' @export +setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) + +#' @rdname sortDF +#' @export +setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") }) + +#' @rdname subtract +#' @export +setGeneric("subtract", function(x, y) { standardGeneric("subtract") }) + +#' @rdname tojson +#' @export +setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) + +#' @rdname DataFrame +#' @export +setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) + +#' @rdname unionAll +#' @export +setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) + +#' @rdname filter +#' @export +setGeneric("where", function(x, condition) { standardGeneric("where") }) + +#' @rdname withColumn +#' @export +setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) + +#' @rdname withColumnRenamed +#' @export +setGeneric("withColumnRenamed", function(x, existingCol, newCol) { + standardGeneric("withColumnRenamed") }) + + +###################### Column Methods ########################## + +#' @rdname column +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname column +#' @export +setGeneric("asc", function(x) { standardGeneric("asc") }) + +#' @rdname column +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname column +#' @export +setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) + +#' @rdname column +#' @export +setGeneric("contains", function(x, ...) { standardGeneric("contains") }) +#' @rdname column +#' @export +setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) + +#' @rdname column +#' @export +setGeneric("desc", function(x) { standardGeneric("desc") }) + +#' @rdname column +#' @export +setGeneric("endsWith", function(x, ...) { standardGeneric("endsWith") }) + +#' @rdname column +#' @export +setGeneric("getField", function(x, ...) { standardGeneric("getField") }) + +#' @rdname column +#' @export +setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) + +#' @rdname column +#' @export +setGeneric("isNull", function(x) { standardGeneric("isNull") }) + +#' @rdname column +#' @export +setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) + +#' @rdname column +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname column +#' @export +setGeneric("like", function(x, ...) { standardGeneric("like") }) + +#' @rdname column +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname column +#' @export +setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) + +#' @rdname column +#' @export +setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) + +#' @rdname column +#' @export +setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) + +#' @rdname column +#' @export +setGeneric("upper", function(x) { standardGeneric("upper") }) + diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R new file mode 100644 index 0000000000000..09fc0a7abe48a --- /dev/null +++ b/R/pkg/R/group.R @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# group.R - GroupedData class and methods implemented in S4 OO classes + +setOldClass("jobj") + +#' @title S4 class that represents a GroupedData +#' @description GroupedDatas can be created using groupBy() on a DataFrame +#' @rdname GroupedData +#' @seealso groupBy +#' +#' @param sgd A Java object reference to the backing Scala GroupedData +#' @export +setClass("GroupedData", + slots = list(sgd = "jobj")) + +setMethod("initialize", "GroupedData", function(.Object, sgd) { + .Object@sgd <- sgd + .Object +}) + +#' @rdname DataFrame +groupedData <- function(sgd) { + new("GroupedData", sgd) +} + + +#' @rdname show +setMethod("show", "GroupedData", + function(object) { + cat("GroupedData\n") + }) + +#' Count +#' +#' Count the number of rows for each group. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @export +#' @examples +#' \dontrun{ +#' count(groupBy(df, "name")) +#' } +setMethod("count", + signature(x = "GroupedData"), + function(x) { + dataFrame(callJMethod(x@sgd, "count")) + }) + +#' Agg +#' +#' Aggregates on the entire DataFrame without groups. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' df2 <- agg(df, = ) +#' df2 <- agg(df, newColName = aggFunction(column)) +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @rdname agg +#' @examples +#' \dontrun{ +#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' +#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' } +setGeneric("agg", function (x, ...) { standardGeneric("agg") }) + +setMethod("agg", + signature(x = "GroupedData"), + function(x, ...) { + cols = list(...) + stopifnot(length(cols) > 0) + if (is.character(cols[[1]])) { + cols <- varargsToEnv(...) + sdf <- callJMethod(x@sgd, "agg", cols) + } else if (class(cols[[1]]) == "Column") { + ns <- names(cols) + if (!is.null(ns)) { + for (n in ns) { + if (n != "") { + cols[[n]] = alias(cols[[n]], n) + } + } + } + jcols <- lapply(cols, function(c) { c@jc }) + # the GroupedData.agg(col, cols*) API does not contain grouping Column + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping", + x@sgd, listToSeq(jcols)) + } else { + stop("agg can only support Column or character") + } + dataFrame(sdf) + }) + + +# sum/mean/avg/min/max +methods <- c("sum", "mean", "avg", "min", "max") + +createMethod <- function(name) { + setMethod(name, + signature(x = "GroupedData"), + function(x, ...) { + sdf <- callJMethod(x@sgd, name, toSeq(...)) + dataFrame(sdf) + }) +} + +createMethods <- function() { + for (name in methods) { + createMethod(name) + } +} + +createMethods() + diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R new file mode 100644 index 0000000000000..4180f146b7fbc --- /dev/null +++ b/R/pkg/R/jobj.R @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# References to objects that exist on the JVM backend +# are maintained using the jobj. + +# Maintain a reference count of Java object references +# This allows us to GC the java object when it is safe +.validJobjs <- new.env(parent = emptyenv()) + +# List of object ids to be removed +.toRemoveJobjs <- new.env(parent = emptyenv()) + +# Check if jobj was created with the current SparkContext +isValidJobj <- function(jobj) { + if (exists(".scStartTime", envir = .sparkREnv)) { + jobj$appId == get(".scStartTime", envir = .sparkREnv) + } else { + FALSE + } +} + +getJobj <- function(objId) { + newObj <- jobj(objId) + if (exists(objId, .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] + 1 + } else { + .validJobjs[[objId]] <- 1 + } + newObj +} + +# Handler for a java object that exists on the backend. +jobj <- function(objId) { + if (!is.character(objId)) { + stop("object id must be a character") + } + # NOTE: We need a new env for a jobj as we can only register + # finalizers for environments or external references pointers. + obj <- structure(new.env(parent = emptyenv()), class = "jobj") + obj$id <- objId + obj$appId <- get(".scStartTime", envir = .sparkREnv) + + # Register a finalizer to remove the Java object when this reference + # is garbage collected in R + reg.finalizer(obj, cleanup.jobj) + obj +} + +#' Print a JVM object reference. +#' +#' This function prints the type and id for an object stored +#' in the SparkR JVM backend. +#' +#' @param x The JVM object reference +#' @param ... further arguments passed to or from other methods +print.jobj <- function(x, ...) { + cls <- callJMethod(x, "getClass") + name <- callJMethod(cls, "getName") + cat("Java ref type", name, "id", x$id, "\n", sep = " ") +} + +cleanup.jobj <- function(jobj) { + if (isValidJobj(jobj)) { + objId <- jobj$id + # If we don't know anything about this jobj, ignore it + if (exists(objId, envir = .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] - 1 + + if (.validJobjs[[objId]] == 0) { + rm(list = objId, envir = .validJobjs) + # NOTE: We cannot call removeJObject here as the finalizer may be run + # in the middle of another RPC. Thus we queue up this object Id to be removed + # and then run all the removeJObject when the next RPC is called. + .toRemoveJobjs[[objId]] <- 1 + } + } + } +} + +clearJobjs <- function() { + valid <- ls(.validJobjs) + rm(list = valid, envir = .validJobjs) + + removeList <- ls(.toRemoveJobjs) + rm(list = removeList, envir = .toRemoveJobjs) +} diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R new file mode 100644 index 0000000000000..c2396c32a7548 --- /dev/null +++ b/R/pkg/R/pairRDD.R @@ -0,0 +1,789 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Operations supported on RDDs contains pairs (i.e key, value) + +############ Actions and Transformations ############ + +#' Look up elements of a key in an RDD +#' +#' @description +#' \code{lookup} returns a list of values in this RDD for key key. +#' +#' @param x The RDD to collect +#' @param key The key to look up for +#' @return a list of values in this RDD for key key +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +#' rdd <- parallelize(sc, pairs) +#' lookup(rdd, 1) # list(1, 3) +#'} +#' @rdname lookup +#' @aliases lookup,RDD-method +setMethod("lookup", + signature(x = "RDD", key = "ANY"), + function(x, key) { + partitionFunc <- function(part) { + filtered <- part[unlist(lapply(part, function(i) { identical(key, i[[1]]) }))] + lapply(filtered, function(i) { i[[2]] }) + } + valsRDD <- lapplyPartition(x, partitionFunc) + collect(valsRDD) + }) + +#' Count the number of elements for each key, and return the result to the +#' master as lists of (key, count) pairs. +#' +#' Same as countByKey in Spark. +#' +#' @param x The RDD to count keys. +#' @return list of (key, count) pairs, where count is number of each key in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +#' countByKey(rdd) # ("a", 2L), ("b", 1L) +#'} +#' @rdname countByKey +#' @aliases countByKey,RDD-method +setMethod("countByKey", + signature(x = "RDD"), + function(x) { + keys <- lapply(x, function(item) { item[[1]] }) + countByValue(keys) + }) + +#' Return an RDD with the keys of each tuple. +#' +#' @param x The RDD from which the keys of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(keys(rdd)) # list(1, 3) +#'} +#' @rdname keys +#' @aliases keys,RDD +setMethod("keys", + signature(x = "RDD"), + function(x) { + func <- function(k) { + k[[1]] + } + lapply(x, func) + }) + +#' Return an RDD with the values of each tuple. +#' +#' @param x The RDD from which the values of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(values(rdd)) # list(2, 4) +#'} +#' @rdname values +#' @aliases values,RDD +setMethod("values", + signature(x = "RDD"), + function(x) { + func <- function(v) { + v[[2]] + } + lapply(x, func) + }) + +#' Applies a function to all values of the elements, without modifying the keys. +#' +#' The same as `mapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' makePairs <- lapply(rdd, function(x) { list(x, x) }) +#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' Output: list(list(1,2), list(2,4), list(3,6), ...) +#'} +#' @rdname mapValues +#' @aliases mapValues,RDD,function-method +setMethod("mapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(x) { + list(x[[1]], FUN(x[[2]])) + } + lapply(X, func) + }) + +#' Pass each value in the key-value pair RDD through a flatMap function without +#' changing the keys; this also retains the original RDD's partitioning. +#' +#' The same as 'flatMapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +#' collect(flatMapValues(rdd, function(x) { x })) +#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#'} +#' @rdname flatMapValues +#' @aliases flatMapValues,RDD,function-method +setMethod("flatMapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + flatMapFunc <- function(x) { + lapply(FUN(x[[2]]), function(v) { list(x[[1]], v) }) + } + flatMap(X, flatMapFunc) + }) + +############ Shuffle Functions ############ + +#' Partition an RDD by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' For each element of this RDD, the partitioner is used to compute a hash +#' function and the RDD is partitioned using this hash value. +#' +#' @param x The RDD to partition. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @param ... Other optional arguments to partitionBy. +#' +#' @param partitionFunc The partition function to use. Uses a default hashCode +#' function if not provided +#' @return An RDD partitioned using the specified partitioner. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- partitionBy(rdd, 2L) +#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#'} +#' @rdname partitionBy +#' @aliases partitionBy,RDD,integer-method +setMethod("partitionBy", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions, partitionFunc = hashCode) { + + #if (missing(partitionFunc)) { + # partitionFunc <- hashCode + #} + + partitionFunc <- cleanClosure(partitionFunc) + serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) + + packageNamesArr <- serialize(.sparkREnv$.packages, + connection = NULL) + broadcastArr <- lapply(ls(.broadcastNames), function(name) { + get(name, .broadcastNames) }) + jrdd <- getJRDD(x) + + # We create a PairwiseRRDD that extends RDD[(Array[Byte], + # Array[Byte])], where the key is the hashed split, the value is + # the content (key-val pairs). + pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", + callJMethod(jrdd, "rdd"), + as.integer(numPartitions), + serializedHashFuncBytes, + getSerializedMode(x), + packageNamesArr, + as.character(.sparkREnv$libname), + broadcastArr, + callJMethod(jrdd, "classTag")) + + # Create a corresponding partitioner. + rPartitioner <- newJObject("org.apache.spark.HashPartitioner", + as.integer(numPartitions)) + + # Call partitionBy on the obtained PairwiseRDD. + javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") + javaPairRDD <- callJMethod(javaPairRDD, "partitionBy", rPartitioner) + + # Call .values() on the result to get back the final result, the + # shuffled acutal content key-val pairs. + r <- callJMethod(javaPairRDD, "values") + + RDD(r, serializedMode = "byte") + }) + +#' Group values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and group values for each key in the RDD into a single sequence. +#' +#' @param x The RDD to group. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, list(V)) +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- groupByKey(rdd, 2L) +#' grouped <- collect(parts) +#' grouped[[1]] # Should be a list(1, list(2, 4)) +#'} +#' @rdname groupByKey +#' @aliases groupByKey,RDD,integer-method +setMethod("groupByKey", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions) { + shuffled <- partitionBy(x, numPartitions) + groupVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + appendList <- function(acc, i) { + addItemToAccumulator(acc, i) + acc + } + makeList <- function(i) { + acc <- initAccumulator() + addItemToAccumulator(acc, i) + acc + } + # Each item in the partition is list of (K, V) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, + appendList, makeList) + }) + # extract out data field + vals <- eapply(vals, + function(i) { + length(i$data) <- i$counter + i$data + }) + # Every key in the environment contains a list + # Convert that to list(K, Seq[V]) + convertEnvsToList(keys, vals) + } + lapplyPartition(shuffled, groupVals) + }) + +#' Merge values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, V') where V' is the merged +#' value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- reduceByKey(rdd, "+", 2L) +#' reduced <- collect(parts) +#' reduced[[1]] # Should be a list(1, 6) +#'} +#' @rdname reduceByKey +#' @aliases reduceByKey,RDD,integer-method +setMethod("reduceByKey", + signature(x = "RDD", combineFunc = "ANY", numPartitions = "integer"), + function(x, combineFunc, numPartitions) { + reduceVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + convertEnvsToList(keys, vals) + } + locallyReduced <- lapplyPartition(x, reduceVals) + shuffled <- partitionBy(locallyReduced, numPartitions) + lapplyPartition(shuffled, reduceVals) + }) + +#' Merge values by key locally +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function, but return the +#' results immediately to the driver as an R list. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @return A list of elements of type list(K, V') where V' is the merged value for each key +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' reduced <- reduceByKeyLocally(rdd, "+") +#' reduced # list(list(1, 6), list(1.1, 3)) +#'} +#' @rdname reduceByKeyLocally +#' @aliases reduceByKeyLocally,RDD,integer-method +setMethod("reduceByKeyLocally", + signature(x = "RDD", combineFunc = "ANY"), + function(x, combineFunc) { + reducePart <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + list(list(keys, vals)) # return hash to avoid re-compute in merge + } + mergeParts <- function(accum, x) { + pred <- function(item) { + exists(item$hash, accum[[1]]) + } + lapply(ls(x[[1]]), + function(name) { + item <- list(x[[1]][[name]], x[[2]][[name]]) + item$hash <- name + updateOrCreatePair(item, accum[[1]], accum[[2]], pred, combineFunc, identity) + }) + accum + } + reduced <- mapPartitions(x, reducePart) + merged <- reduce(reduced, mergeParts) + convertEnvsToList(merged[[1]], merged[[2]]) + }) + +#' Combine values by key +#' +#' Generic function to combine the elements for each key using a custom set of +#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +#' for a "combined type" C. Note that V and C can be different -- for example, one +#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). + +#' Users provide three functions: +#' \itemize{ +#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +#' \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +#' \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +#' two lists). +#' } +#' +#' @param x The RDD to combine. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param createCombiner Create a combiner (C) given a value (V) +#' @param mergeValue Merge the given value (V) with an existing combiner (C) +#' @param mergeCombiners Merge two combiners and return a new combiner +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, C) where C is the combined type +#' +#' @seealso groupByKey, reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +#' combined <- collect(parts) +#' combined[[1]] # Should be a list(1, 6) +#'} +#' @rdname combineByKey +#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("combineByKey", + signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", + mergeCombiners = "ANY", numPartitions = "integer"), + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + combineLocally <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) + }) + convertEnvsToList(keys, combiners) + } + locallyCombined <- lapplyPartition(x, combineLocally) + shuffled <- partitionBy(locallyCombined, numPartitions) + mergeAfterShuffle <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) + }) + convertEnvsToList(keys, combiners) + } + lapplyPartition(shuffled, mergeAfterShuffle) + }) + +#' Aggregate a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using given combine functions +#' and a neutral "zero value". This function can return a different result type, +#' U, than the type of the values in this RDD, V. Thus, we need one operation +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument +#' instead of creating a new U. +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the values of each key. It may return +#' a different result type from the type of the values. +#' @param combOp A function to aggregate results of seqOp. +#' @return An RDD containing the aggregation result. +#' @seealso foldByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' # list(list(1, list(3, 2)), list(2, list(7, 2))) +#'} +#' @rdname aggregateByKey +#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("aggregateByKey", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", + combOp = "ANY", numPartitions = "integer"), + function(x, zeroValue, seqOp, combOp, numPartitions) { + createCombiner <- function(v) { + do.call(seqOp, list(zeroValue, v)) + } + + combineByKey(x, createCombiner, seqOp, combOp, numPartitions) + }) + +#' Fold a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using an associative function "func" +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or +#' 1 for multiplication.). +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param func An associative function for folding values of each key. +#' @return An RDD containing the aggregation result. +#' @seealso aggregateByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#'} +#' @rdname foldByKey +#' @aliases foldByKey,RDD,ANY,ANY,integer-method +setMethod("foldByKey", + signature(x = "RDD", zeroValue = "ANY", + func = "ANY", numPartitions = "integer"), + function(x, zeroValue, func, numPartitions) { + aggregateByKey(x, zeroValue, func, func, numPartitions) + }) + +############ Binary Functions ############# + +#' Join two RDDs +#' +#' @description +#' \code{join} This function joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with matching keys in +#' two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#'} +#' @rdname join-methods +#' @aliases join,RDD,RDD-method +setMethod("join", + signature(x = "RDD", y = "RDD"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), + doJoin) + }) + +#' Left outer join two RDDs +#' +#' @description +#' \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' if no elements in rdd2 have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' leftOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#'} +#' @rdname join-methods +#' @aliases leftOuterJoin,RDD,RDD-method +setMethod("leftOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Right outer join two RDDs +#' +#' @description +#' \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, w) in y, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +#' if no elements in x have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rightOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#'} +#' @rdname join-methods +#' @aliases rightOuterJoin,RDD,RDD-method +setMethod("rightOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Full outer join two RDDs +#' +#' @description +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD +#' will contain all pairs (k, (v, w)) for both (k, v) in x and +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' in x/y have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +#' # list(1, list(3, 1)), +#' # list(2, list(NULL, 4))) +#' # list(3, list(3, NULL)), +#'} +#' @rdname join-methods +#' @aliases fullOuterJoin,RDD,RDD-method +setMethod("fullOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' For each key k in several RDDs, return a resulting RDD that +#' whose values are a list of values for the key in all RDDs. +#' +#' @param ... Several RDDs. +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with values in a list +#' in all RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#'} +#' @rdname cogroup +#' @aliases cogroup,RDD-method +setMethod("cogroup", + "RDD", + function(..., numPartitions) { + rdds <- list(...) + rddsLen <- length(rdds) + for (i in 1:rddsLen) { + rdds[[i]] <- lapply(rdds[[i]], + function(x) { list(x[[1]], list(i, x[[2]])) }) + # TODO(hao): As issue [SparkR-142] mentions, the right value of i + # will not be captured into UDF if getJRDD is not invoked. + # It should be resolved together with that issue. + getJRDD(rdds[[i]]) # Capture the closure. + } + union.rdd <- Reduce(unionRDD, rdds) + group.func <- function(vlist) { + res <- list() + length(res) <- rddsLen + for (x in vlist) { + i <- x[[1]] + acc <- res[[i]] + # Create an accumulator. + if (is.null(acc)) { + acc <- initAccumulator() + } + addItemToAccumulator(acc, x[[2]]) + res[[i]] <- acc + } + lapply(res, function(acc) { + if (is.null(acc)) { + list() + } else { + acc$data + } + }) + } + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + group.func) + }) + +#' Sort a (k, v) pair RDD by k. +#' +#' @param x A (k, v) pair RDD to be sorted. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all (k, v) pair elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#'} +#' @rdname sortByKey +#' @aliases sortByKey,RDD,RDD-method +setMethod("sortByKey", + signature(x = "RDD"), + function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + rangeBounds <- list() + + if (numPartitions > 1) { + rddSize <- count(x) + # constant from Spark's RangePartitioner + maxSampleSize <- numPartitions * 20 + fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) + + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) + + # Note: the built-in R sort() function only works on atomic vectors + samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) + + if (length(samples) > 0) { + rangeBounds <- lapply(seq_len(numPartitions - 1), + function(i) { + j <- ceiling(length(samples) * i / numPartitions) + samples[j] + }) + } + } + + rangePartitionFunc <- function(key) { + partition <- 0 + + # TODO: Use binary search instead of linear search, similar with Spark + while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { + partition <- partition + 1 + } + + if (ascending) { + partition + } else { + numPartitions - partition - 1 + } + } + + partitionFunc <- function(part) { + sortKeyValueList(part, decreasing = !ascending) + } + + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) + lapplyPartition(newRDD, partitionFunc) + }) + diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R new file mode 100644 index 0000000000000..8a9c0c652ce24 --- /dev/null +++ b/R/pkg/R/serialize.R @@ -0,0 +1,195 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to serialize R objects so they can be read in Java. + +# Type mapping from R to Java +# +# NULL -> Void +# integer -> Int +# character -> String +# logical -> Boolean +# double, numeric -> Double +# raw -> Array[Byte] +# Date -> Date +# POSIXct,POSIXlt -> Time +# +# list[T] -> Array[T], where T is one of above mentioned types +# environment -> Map[String, T], where T is a native type +# jobj -> Object, where jobj is an object created in the backend + +writeObject <- function(con, object, writeType = TRUE) { + # NOTE: In R vectors have same type as objects. So we don't support + # passing in vectors as arrays and instead require arrays to be passed + # as lists. + type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + if (writeType) { + writeType(con, type) + } + switch(type, + NULL = writeVoid(con), + integer = writeInt(con, object), + character = writeString(con, object), + logical = writeBoolean(con, object), + double = writeDouble(con, object), + numeric = writeDouble(con, object), + raw = writeRaw(con, object), + list = writeList(con, object), + jobj = writeJobj(con, object), + environment = writeEnv(con, object), + Date = writeDate(con, object), + POSIXlt = writeTime(con, object), + POSIXct = writeTime(con, object), + stop(paste("Unsupported type for serialization", type))) +} + +writeVoid <- function(con) { + # no value for NULL +} + +writeJobj <- function(con, value) { + if (!isValidJobj(value)) { + stop("invalid jobj ", value$id) + } + writeString(con, value$id) +} + +writeString <- function(con, value) { + writeInt(con, as.integer(nchar(value) + 1)) + writeBin(value, con, endian = "big") +} + +writeInt <- function(con, value) { + writeBin(as.integer(value), con, endian = "big") +} + +writeDouble <- function(con, value) { + writeBin(value, con, endian = "big") +} + +writeBoolean <- function(con, value) { + # TRUE becomes 1, FALSE becomes 0 + writeInt(con, as.integer(value)) +} + +writeRawSerialize <- function(outputCon, batch) { + outputSer <- serialize(batch, ascii = FALSE, connection = NULL) + writeRaw(outputCon, outputSer) +} + +writeRowSerialize <- function(outputCon, rows) { + invisible(lapply(rows, function(r) { + bytes <- serializeRow(r) + writeRaw(outputCon, bytes) + })) +} + +serializeRow <- function(row) { + rawObj <- rawConnection(raw(0), "wb") + on.exit(close(rawObj)) + writeRow(rawObj, row) + rawConnectionValue(rawObj) +} + +writeRow <- function(con, row) { + numCols <- length(row) + writeInt(con, numCols) + for (i in 1:numCols) { + writeObject(con, row[[i]]) + } +} + +writeRaw <- function(con, batch) { + writeInt(con, length(batch)) + writeBin(batch, con, endian = "big") +} + +writeType <- function(con, class) { + type <- switch(class, + NULL = "n", + integer = "i", + character = "c", + logical = "b", + double = "d", + numeric = "d", + raw = "r", + list = "l", + jobj = "j", + environment = "e", + Date = "D", + POSIXlt = 't', + POSIXct = 't', + stop(paste("Unsupported type for serialization", class))) + writeBin(charToRaw(type), con) +} + +# Used to pass arrays where all the elements are of the same type +writeList <- function(con, arr) { + # All elements should be of same type + elemType <- unique(sapply(arr, function(elem) { class(elem) })) + stopifnot(length(elemType) <= 1) + + # TODO: Empty lists are given type "character" right now. + # This may not work if the Java side expects array of any other type. + if (length(elemType) == 0) { + elemType <- class("somestring") + } + + writeType(con, elemType) + writeInt(con, length(arr)) + + if (length(arr) > 0) { + for (a in arr) { + writeObject(con, a, FALSE) + } + } +} + +# Used to pass in hash maps required on Java side. +writeEnv <- function(con, env) { + len <- length(env) + + writeInt(con, len) + if (len > 0) { + writeList(con, as.list(ls(env))) + vals <- lapply(ls(env), function(x) { env[[x]] }) + writeList(con, as.list(vals)) + } +} + +writeDate <- function(con, date) { + writeString(con, as.character(date)) +} + +writeTime <- function(con, time) { + writeDouble(con, as.double(time)) +} + +# Used to serialize in a list of objects where each +# object can be of a different type. Serialization format is +# for each object +writeArgs <- function(con, args) { + if (length(args) > 0) { + for (a in args) { + writeObject(con, a) + } + } +} + +writeStrings <- function(con, stringList) { + writeLines(unlist(stringList), con) +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R new file mode 100644 index 0000000000000..bc82df01f0fff --- /dev/null +++ b/R/pkg/R/sparkR.R @@ -0,0 +1,266 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.sparkREnv <- new.env() + +sparkR.onLoad <- function(libname, pkgname) { + .sparkREnv$libname <- libname +} + +# Utility function that returns TRUE if we have an active connection to the +# backend and FALSE otherwise +connExists <- function(env) { + tryCatch({ + exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) + }, error = function(err) { + return(FALSE) + }) +} + +#' Stop the Spark context. +#' +#' Also terminates the backend this R session is connected to +sparkR.stop <- function() { + env <- .sparkREnv + if (exists(".sparkRCon", envir = env)) { + # cat("Stopping SparkR\n") + if (exists(".sparkRjsc", envir = env)) { + sc <- get(".sparkRjsc", envir = env) + callJMethod(sc, "stop") + rm(".sparkRjsc", envir = env) + } + + if (exists(".backendLaunched", envir = env)) { + callJStatic("SparkRHandler", "stopBackend") + } + + # Also close the connection and remove it from our env + conn <- get(".sparkRCon", envir = env) + close(conn) + + rm(".sparkRCon", envir = env) + rm(".scStartTime", envir = env) + } + + if (exists(".monitorConn", envir = env)) { + conn <- get(".monitorConn", envir = env) + close(conn) + rm(".monitorConn", envir = env) + } + + # Clear all broadcast variables we have + # as the jobj will not be valid if we restart the JVM + clearBroadcastVariables() + + # Clear jobj maps + clearJobjs() +} + +#' Initialize a new Spark Context. +#' +#' This function initializes a new SparkContext. +#' +#' @param master The Spark master URL. +#' @param appName Application name to register with cluster manager +#' @param sparkHome Spark Home directory +#' @param sparkEnvir Named list of environment variables to set on worker nodes. +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. +#' @param sparkJars Character string vector of jar files to pass to the worker nodes. +#' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g")) +#' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g"), +#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), +#' c("jarfile1.jar","jarfile2.jar")) +#'} + +sparkR.init <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkEnvir = list(), + sparkExecutorEnv = list(), + sparkJars = "", + sparkRLibDir = "") { + + if (exists(".sparkRjsc", envir = .sparkREnv)) { + cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + return(get(".sparkRjsc", envir = .sparkREnv)) + } + + sparkMem <- Sys.getenv("SPARK_MEM", "512m") + jars <- suppressWarnings(normalizePath(as.character(sparkJars))) + + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + collapseChar <- ":" + uriSep <- "//" + } else { + collapseChar <- ";" + uriSep <- "////" + } + + existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + if (existingPort != "") { + backendPort <- existingPort + } else { + path <- tempfile(pattern = "backend_port") + launchBackend( + args = path, + sparkHome = sparkHome, + jars = jars, + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + # wait atmost 100 seconds for JVM to launch + wait <- 0.1 + for (i in 1:25) { + Sys.sleep(wait) + if (file.exists(path)) { + break + } + wait <- wait * 1.25 + } + if (!file.exists(path)) { + stop("JVM is not ready after 10 seconds") + } + f <- file(path, open='rb') + backendPort <- readInt(f) + monitorPort <- readInt(f) + close(f) + file.remove(path) + if (length(backendPort) == 0 || backendPort == 0 || + length(monitorPort) == 0 || monitorPort == 0) { + stop("JVM failed to launch") + } + assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".backendLaunched", 1, envir = .sparkREnv) + } + + .sparkREnv$backendPort <- backendPort + tryCatch({ + connectBackend("localhost", backendPort) + }, error = function(err) { + stop("Failed to connect JVM\n") + }) + + if (nchar(sparkHome) != 0) { + sparkHome <- normalizePath(sparkHome) + } + + if (nchar(sparkRLibDir) != 0) { + .sparkREnv$libname <- sparkRLibDir + } + + sparkEnvirMap <- new.env() + for (varname in names(sparkEnvir)) { + sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] + } + + sparkExecutorEnvMap <- new.env() + if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + } + for (varname in names(sparkExecutorEnv)) { + sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] + } + + nonEmptyJars <- Filter(function(x) { x != "" }, jars) + localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + + # Set the start time to identify jobjs + # Seconds resolution is good enough for this purpose, so use ints + assign(".scStartTime", as.integer(Sys.time()), envir = .sparkREnv) + + assign( + ".sparkRjsc", + callJStatic( + "org.apache.spark.api.r.RRDD", + "createSparkContext", + master, + appName, + as.character(sparkHome), + as.list(localJarPaths), + sparkEnvirMap, + sparkExecutorEnvMap), + envir = .sparkREnv + ) + + sc <- get(".sparkRjsc", envir = .sparkREnv) + + # Register a finalizer to sleep 1 seconds on R exit to make RStudio happy + reg.finalizer(.sparkREnv, function(x) { Sys.sleep(1) }, onexit = TRUE) + + sc +} + +#' Initialize a new SQLContext. +#' +#' This function creates a SparkContext from an existing JavaSparkContext and +#' then uses it to initialize a new SQLContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#'} + +sparkRSQL.init <- function(jsc) { + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + return(get(".sparkRSQLsc", envir = .sparkREnv)) + } + + sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + jsc) + assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) + sqlCtx +} + +#' Initialize a new HiveContext. +#' +#' This function creates a HiveContext from an existing JavaSparkContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRHive.init(sc) +#'} + +sparkRHive.init <- function(jsc) { + if (exists(".sparkRHivesc", envir = .sparkREnv)) { + return(get(".sparkRHivesc", envir = .sparkREnv)) + } + + ssc <- callJMethod(jsc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.HiveContext", ssc) + }, error = function(err) { + stop("Spark SQL is not built with Hive support") + }) + + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + hiveCtx +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R new file mode 100644 index 0000000000000..c337fb0751e72 --- /dev/null +++ b/R/pkg/R/utils.R @@ -0,0 +1,467 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utilities and Helpers + +# Given a JList, returns an R list containing the same elements, the number +# of which is optionally upper bounded by `logicalUpperBound` (by default, +# return all elements). Takes care of deserializations and type conversions. +convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, + serializedMode = "byte") { + arrSize <- callJMethod(jList, "size") + + # Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()): + # each partition is not dense-packed into one Array[Byte], and `arrSize` + # here corresponds to number of logical elements. Thus we can prune here. + if (serializedMode == "string" && !is.null(logicalUpperBound)) { + arrSize <- min(arrSize, logicalUpperBound) + } + + results <- if (arrSize > 0) { + lapply(0:(arrSize - 1), + function(index) { + obj <- callJMethod(jList, "get", as.integer(index)) + + # Assume it is either an R object or a Java obj ref. + if (inherits(obj, "jobj")) { + if (isInstanceOf(obj, "scala.Tuple2")) { + # JavaPairRDD[Array[Byte], Array[Byte]]. + + keyBytes = callJMethod(obj, "_1") + valBytes = callJMethod(obj, "_2") + res <- list(unserialize(keyBytes), + unserialize(valBytes)) + } else { + stop(paste("utils.R: convertJListToRList only supports", + "RDD[Array[Byte]] and", + "JavaPairRDD[Array[Byte], Array[Byte]] for now")) + } + } else { + if (inherits(obj, "raw")) { + if (serializedMode == "byte") { + # RDD[Array[Byte]]. `obj` is a whole partition. + res <- unserialize(obj) + # For serialized datasets, `obj` (and `rRaw`) here corresponds to + # one whole partition dense-packed together. We deserialize the + # whole partition first, then cap the number of elements to be returned. + } else if (serializedMode == "row") { + res <- readRowList(obj) + # For DataFrames that have been converted to RRDDs, we call readRowList + # which will read in each row of the RRDD as a list and deserialize + # each element. + flatten <<- FALSE + # Use global assignment to change the flatten flag. This means + # we don't have to worry about the default argument in other functions + # e.g. collect + } + # TODO: is it possible to distinguish element boundary so that we can + # unserialize only what we need? + if (!is.null(logicalUpperBound)) { + res <- head(res, n = logicalUpperBound) + } + } else { + # obj is of a primitive Java type, is simplified to R's + # corresponding type. + res <- list(obj) + } + } + res + }) + } else { + list() + } + + if (flatten) { + as.list(unlist(results, recursive = FALSE)) + } else { + as.list(results) + } +} + +# Returns TRUE if `name` refers to an RDD in the given environment `env` +isRDD <- function(name, env) { + obj <- get(name, envir = env) + inherits(obj, "RDD") +} + +#' Compute the hashCode of an object +#' +#' Java-style function to compute the hashCode for the given object. Returns +#' an integer value. +#' +#' @details +#' This only works for integer, numeric and character types right now. +#' +#' @param key the object to be hashed +#' @return the hash code as an integer +#' @export +#' @examples +#' hashCode(1L) # 1 +#' hashCode(1.0) # 1072693248 +#' hashCode("1") # 49 +hashCode <- function(key) { + if (class(key) == "integer") { + as.integer(key[[1]]) + } else if (class(key) == "numeric") { + # Convert the double to long and then calculate the hash code + rawVec <- writeBin(key[[1]], con = raw()) + intBits <- packBits(rawToBits(rawVec), "integer") + as.integer(bitwXor(intBits[2], intBits[1])) + } else if (class(key) == "character") { + .Call("stringHashCode", key) + } else { + warning(paste("Could not hash object, returning 0", sep = "")) + as.integer(0) + } +} + +# Create a new RDD with serializedMode == "byte". +# Return itself if already in "byte" format. +serializeToBytes <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "byte") { + ser.rdd <- lapply(rdd, function(x) { x }) + return(ser.rdd) + } else { + return(rdd) + } +} + +# Create a new RDD with serializedMode == "string". +# Return itself if already in "string" format. +serializeToString <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "string") { + ser.rdd <- lapply(rdd, function(x) { toString(x) }) + # force it to create jrdd using "string" + getJRDD(ser.rdd, serializedMode = "string") + return(ser.rdd) + } else { + return(rdd) + } +} + +# Fast append to list by using an accumulator. +# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r +# +# The accumulator should has three fields size, counter and data. +# This function amortizes the allocation cost by doubling +# the size of the list every time it fills up. +addItemToAccumulator <- function(acc, item) { + if(acc$counter == acc$size) { + acc$size <- acc$size * 2 + length(acc$data) <- acc$size + } + acc$counter <- acc$counter + 1 + acc$data[[acc$counter]] <- item +} + +initAccumulator <- function() { + acc <- new.env() + acc$counter <- 0 + acc$data <- list(NULL) + acc$size <- 1 + acc +} + +# Utility function to sort a list of key value pairs +# Used in unit tests +sortKeyValueList <- function(kv_list, decreasing = FALSE) { + keys <- sapply(kv_list, function(x) x[[1]]) + kv_list[order(keys, decreasing = decreasing)] +} + +# Utility function to generate compact R lists from grouped rdd +# Used in Join-family functions +# param: +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +genCompactLists <- function(tagged_list, cnull) { + len <- length(tagged_list) + lists <- list(vector("list", len), vector("list", len)) + index <- list(1, 1) + + for (x in tagged_list) { + tag <- x[[1]] + idx <- index[[tag]] + lists[[tag]][[idx]] <- x[[2]] + index[[tag]] <- idx + 1 + } + + len <- lapply(index, function(x) x - 1) + for (i in (1:2)) { + if (cnull[[i]] && len[[i]] == 0) { + lists[[i]] <- list(NULL) + } else { + length(lists[[i]]) <- len[[i]] + } + } + + lists +} + +# Utility function to merge compact R lists +# Used in Join-family functions +# param: +# left/right Two compact lists ready for Cartesian product +mergeCompactLists <- function(left, right) { + result <- list() + length(result) <- length(left) * length(right) + index <- 1 + for (i in left) { + for (j in right) { + result[[index]] <- list(i, j) + index <- index + 1 + } + } + result +} + +# Utility function to wrapper above two operations +# Used in Join-family functions +# param (same as genCompactLists): +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +joinTaggedList <- function(tagged_list, cnull) { + lists <- genCompactLists(tagged_list, cnull) + mergeCompactLists(lists[[1]], lists[[2]]) +} + +# Utility function to reduce a key-value list with predicate +# Used in *ByKey functions +# param +# pair key-value pair +# keys/vals env of key/value with hashes +# updateOrCreatePred predicate function +# updateFn update or merge function for existing pair, similar with `mergeVal` @combineByKey +# createFn create function for new pair, similar with `createCombiner` @combinebykey +updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) { + # assume hashVal bind to `$hash`, key/val with index 1/2 + hashVal <- pair$hash + key <- pair[[1]] + val <- pair[[2]] + if (updateOrCreatePred(pair)) { + assign(hashVal, do.call(updateFn, list(get(hashVal, envir = vals), val)), envir = vals) + } else { + assign(hashVal, do.call(createFn, list(val)), envir = vals) + assign(hashVal, key, envir = keys) + } +} + +# Utility function to convert key&values envs into key-val list +convertEnvsToList <- function(keys, vals) { + lapply(ls(keys), + function(name) { + list(keys[[name]], vals[[name]]) + }) +} + +# Utility function to capture the varargs into environment object +varargsToEnv <- function(...) { + pairs <- as.list(substitute(list(...)))[-1L] + env <- new.env() + for (name in names(pairs)) { + env[[name]] <- pairs[[name]] + } + env +} + +getStorageLevel <- function(newLevel = c("DISK_ONLY", + "DISK_ONLY_2", + "MEMORY_AND_DISK", + "MEMORY_AND_DISK_2", + "MEMORY_AND_DISK_SER", + "MEMORY_AND_DISK_SER_2", + "MEMORY_ONLY", + "MEMORY_ONLY_2", + "MEMORY_ONLY_SER", + "MEMORY_ONLY_SER_2", + "OFF_HEAP")) { + match.arg(newLevel) + storageLevel <- switch(newLevel, + "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) +} + +# Utility function for functions where an argument needs to be integer but we want to allow +# the user to type (for example) `5` instead of `5L` to avoid a confusing error message. +numToInt <- function(num) { + if (as.integer(num) != num) { + warning(paste("Coercing", as.list(sys.call())[[2]], "to integer.")) + } + as.integer(num) +} + +# create a Seq in JVM +toSeq <- function(...) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) +} + +# create a Seq in JVM from a list +listToSeq <- function(l) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) +} + +# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a +# user defined function (UDF), and to examine variables in the UDF to decide +# if their values should be included in the new function environment. +# param +# node The current AST node in the traversal. +# oldEnv The original function environment. +# defVars An Accumulator of variables names defined in the function's calling environment, +# including function argument and local variable names. +# checkedFunc An environment of function objects examined during cleanClosure. It can +# be considered as a "name"-to-"list of functions" mapping. +# newEnv A new function environment to store necessary function dependencies, an output argument. +processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { + nodeLen <- length(node) + + if (nodeLen > 1 && typeof(node) == "language") { + # Recursive case: current AST node is an internal node, check for its children. + if (length(node[[1]]) > 1) { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else { # if node[[1]] is length of 1, check for some R special functions. + nodeChar <- as.character(node[[1]]) + if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + for (i in 2:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "<-" || nodeChar == "=" || + nodeChar == "<<-") { # Assignment Ops. + defVar <- node[[2]] + if (length(defVar) == 1 && typeof(defVar) == "symbol") { + # Add the defined variable name into defVars. + addItemToAccumulator(defVars, as.character(defVar)) + } else { + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "function") { # Function definition. + # Add parameter names. + newArgs <- names(node[[2]]) + lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "$") { # Skip the field. + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } else if (nodeChar == "::" || nodeChar == ":::") { + processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) + } else { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } + } + } else if (nodeLen == 1 && + (typeof(node) == "symbol" || typeof(node) == "language")) { + # Base case: current AST node is a leaf node and a symbol or a function call. + nodeChar <- as.character(node) + if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + func.env <- oldEnv + topEnv <- parent.env(.GlobalEnv) + # Search in function environment, and function's enclosing environments + # up to global environment. There is no need to look into package environments + # above the global or namespace environment that is not SparkR below the global, + # as they are assumed to be loaded on workers. + while (!identical(func.env, topEnv)) { + # Namespaces other than "SparkR" will not be searched. + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in + # attached package environments. + if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), + error = function(e) { FALSE })) { + obj <- get(nodeChar, envir = func.env, inherits = FALSE) + if (is.function(obj)) { # If the node is a function call. + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + ifnotfound = list(list(NULL)))[[1]] + found <- sapply(funcList, function(func) { + ifelse(identical(func, obj), TRUE, FALSE) + }) + if (sum(found) > 0) { # If function has been examined, ignore. + break + } + # Function has not been examined, record it and recursively clean its closure. + assign(nodeChar, + if (is.null(funcList[[1]])) { + list(obj) + } else { + append(funcList, obj) + }, + envir = checkedFuncs) + obj <- cleanClosure(obj, checkedFuncs) + } + assign(nodeChar, obj, envir = newEnv) + break + } + } + + # Continue to search in enclosure. + func.env <- parent.env(func.env) + } + } + } +} + +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined +# outside a UDF, and stores them in the function's environment. +# param +# func A function whose closure needs to be captured. +# checkedFunc An environment of function objects examined during cleanClosure. It can be +# considered as a "name"-to-"list of functions" mapping. +# return value +# a new version of func that has an correct environment (closure). +cleanClosure <- function(func, checkedFuncs = new.env()) { + if (is.function(func)) { + newEnv <- new.env(parent = .GlobalEnv) + func.body <- body(func) + oldEnv <- environment(func) + # defVars is an Accumulator of variables names defined in the function's calling + # environment. First, function's arguments are added to defVars. + defVars <- initAccumulator() + argNames <- names(as.list(args(func))) + for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + addItemToAccumulator(defVars, argNames[i]) + } + # Recursively examine variables in the function body. + processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv) + environment(func) <- newEnv + } + func +} diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R new file mode 100644 index 0000000000000..80d796d467943 --- /dev/null +++ b/R/pkg/R/zzz.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.onLoad <- function(libname, pkgname) { + sparkR.onLoad(libname, pkgname) +} + diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R new file mode 100644 index 0000000000000..8fe711b622086 --- /dev/null +++ b/R/pkg/inst/profile/general.R @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) +} diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R new file mode 100644 index 0000000000000..7a7f2031152a0 --- /dev/null +++ b/R/pkg/inst/profile/shell.R @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) + + library(utils) + library(SparkR) + sc <- sparkR.init(Sys.getenv("MASTER", unset = "")) + assign("sc", sc, envir=.GlobalEnv) + sqlCtx <- sparkRSQL.init(sc) + assign("sqlCtx", sqlCtx, envir=.GlobalEnv) + cat("\n Welcome to SparkR!") + cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") +} diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R new file mode 100644 index 0000000000000..4bb5f58d83dc9 --- /dev/null +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions on binary files") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("saveAsObjectFile()/objectFile() following textFile() works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + saveAsObjectFile(rdd, fileName2) + rdd <- objectFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + + l <- list(1, 2, 3) + rdd <- parallelize(sc, l) + saveAsObjectFile(rdd, fileName) + rdd <- objectFile(sc, fileName) + expect_equal(collect(rdd), l) + + unlink(fileName, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsObjectFile(counts, fileName2) + counts <- objectFile(sc, fileName2) + + output <- collect(counts) + expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), + list("is", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + + rdd1 <- parallelize(sc, "Spark is pretty.") + saveAsObjectFile(rdd1, fileName1) + rdd2 <- parallelize(sc, "Spark is awesome.") + saveAsObjectFile(rdd2, fileName2) + + rdd <- objectFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1, recursive = TRUE) + unlink(fileName2, recursive = TRUE) +}) + diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R new file mode 100644 index 0000000000000..c15553ba28517 --- /dev/null +++ b/R/pkg/inst/tests/test_binary_function.R @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("binary functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +# File content +mockFile <- c("Spark is pretty.", "Spark is awesome.") + +test_that("union on two RDDs", { + actual <- collect(unionRDD(rdd, rdd)) + expect_equal(actual, as.list(rep(nums, 2))) + + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, c(as.list(nums), mockFile)) + expect_true(getSerializedMode(union.rdd) == "byte") + + rdd<- map(text.rdd, function(x) {x}) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, as.list(c(mockFile, mockFile))) + expect_true(getSerializedMode(union.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cogroup on two RDDs", { + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + expect_equal(actual, + list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) + + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) + rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + + expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/test_broadcast.R new file mode 100644 index 0000000000000..fee91a427d6d5 --- /dev/null +++ b/R/pkg/inst/tests/test_broadcast.R @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("broadcast variables") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rrdd <- parallelize(sc, nums, 2L) + +test_that("using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMatBr <- broadcast(sc, randomMat) + + useBroadcast <- function(x) { + sum(value(randomMatBr) * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) + +test_that("without using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + + useBroadcast <- function(x) { + sum(randomMat * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R new file mode 100644 index 0000000000000..e4aab37436a74 --- /dev/null +++ b/R/pkg/inst/tests/test_context.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("test functions in sparkR.R") + +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- sparkR.init() + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + sparkR.stop() + } +}) + +test_that("rdd GC across sparkR.stop", { + sparkR.stop() + sc <- sparkR.init() # sc should get id 0 + rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 + rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 + sparkR.stop() + + sc <- sparkR.init() # sc should get id 0 again + + # GC rdd1 before creating rdd3 and rdd2 after + rm(rdd1) + gc() + + rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now + rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now + + rm(rdd2) + gc() + + count(rdd3) + count(rdd4) +}) diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/test_includePackage.R new file mode 100644 index 0000000000000..8152b448d0870 --- /dev/null +++ b/R/pkg/inst/tests/test_includePackage.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("include R packages") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rdd <- parallelize(sc, nums, 2L) + +test_that("include inside function", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + suppressPackageStartupMessages(library(plyr)) + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) + +test_that("use include package", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + includePackage(sc, plyr) + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R new file mode 100644 index 0000000000000..fff028657db37 --- /dev/null +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("parallelize() and collect()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) +strPairs <- list(list(strList, strList), list(strList, strList)) + +# JavaSparkContext handle +jsc <- sparkR.init() + +# Tests + +test_that("parallelize() on simple vectors and lists returns an RDD", { + numVectorRDD <- parallelize(jsc, numVector, 1) + numVectorRDD2 <- parallelize(jsc, numVector, 10) + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + + rdds <- c(numVectorRDD, + numVectorRDD2, + numListRDD, + numListRDD2, + strVectorRDD, + strVectorRDD2, + strListRDD, + strListRDD2) + + for (rdd in rdds) { + expect_true(inherits(rdd, "RDD")) + expect_true(.hasSlot(rdd, "jrdd") + && inherits(rdd@jrdd, "jobj") + && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) + } +}) + +test_that("collect(), following a parallelize(), gives back the original collections", { + numVectorRDD <- parallelize(jsc, numVector, 10) + expect_equal(collect(numVectorRDD), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(collect(numListRDD), as.list(numList)) + expect_equal(collect(numListRDD2), as.list(numList)) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(collect(strVectorRDD), as.list(strVector)) + expect_equal(collect(strVectorRDD2), as.list(strVector)) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(collect(strListRDD), as.list(strList)) + expect_equal(collect(strListRDD2), as.list(strList)) +}) + +test_that("regression: collect() following a parallelize() does not drop elements", { + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 + collLen <- 10 + numPart <- 6 + expected <- runif(collLen) + actual <- collect(parallelize(jsc, expected, numPart)) + expect_equal(actual, as.list(expected)) +}) + +test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + # use the pairwise logical to indicate pairwise data + numPairsRDDD1 <- parallelize(jsc, numPairs, 1) + numPairsRDDD2 <- parallelize(jsc, numPairs, 2) + numPairsRDDD3 <- parallelize(jsc, numPairs, 3) + expect_equal(collect(numPairsRDDD1), numPairs) + expect_equal(collect(numPairsRDDD2), numPairs) + expect_equal(collect(numPairsRDDD3), numPairs) + # can also leave out the parameter name, if the params are supplied in order + strPairsRDDD1 <- parallelize(jsc, strPairs, 1) + strPairsRDDD2 <- parallelize(jsc, strPairs, 2) + expect_equal(collect(strPairsRDDD1), strPairs) + expect_equal(collect(strPairsRDDD2), strPairs) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R new file mode 100644 index 0000000000000..f75e0817b9406 --- /dev/null +++ b/R/pkg/inst/tests/test_rdd.R @@ -0,0 +1,644 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("basic RDD functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +test_that("get number of partitions in RDD", { + expect_equal(numPartitions(rdd), 2) + expect_equal(numPartitions(intRdd), 2) +}) + +test_that("first on RDD", { + expect_true(first(rdd) == 1) + newrdd <- lapply(rdd, function(x) x + 1) + expect_true(first(newrdd) == 2) +}) + +test_that("count and length on RDD", { + expect_equal(count(rdd), 10) + expect_equal(length(rdd), 10) +}) + +test_that("count by values and keys", { + mods <- lapply(rdd, function(x) { x %% 3 }) + actual <- countByValue(mods) + expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + actual <- countByKey(intRdd) + expected <- list(list(2L, 2L), list(1L, 2L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("lapply on RDD", { + multiples <- lapply(rdd, function(x) { 2 * x }) + actual <- collect(multiples) + expect_equal(actual, as.list(nums * 2)) +}) + +test_that("lapplyPartition on RDD", { + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("mapPartitions on RDD", { + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("flatMap() on RDDs", { + flat <- flatMap(intRdd, function(x) { list(x, x) }) + actual <- collect(flat) + expect_equal(actual, rep(intPairs, each=2)) +}) + +test_that("filterRDD on RDD", { + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list(2, 4, 6, 8, 10)) + + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) + actual <- collect(filtered.rdd) + expect_equal(actual, list(list(1L, -1))) + + # Filter out all elements. + filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list()) +}) + +test_that("lookup on RDD", { + vals <- lookup(intRdd, 1L) + expect_equal(vals, list(-1, 200)) + + vals <- lookup(intRdd, 3L) + expect_equal(vals, list()) +}) + +test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + rdd2 <- rdd + for (i in 1:12) + rdd2 <- lapplyPartitionsWithIndex( + rdd2, function(split, part) { + part <- as.list(unlist(part) * split + i) + }) + rdd2 <- lapply(rdd2, function(x) x + x) + actual <- collect(rdd2) + expected <- list(24, 24, 24, 24, 24, + 168, 170, 172, 174, 176) + expect_equal(actual, expected) +}) + +test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + # RDD + rdd2 <- rdd + # PipelinedRDD + rdd2 <- lapplyPartitionsWithIndex( + rdd2, + function(split, part) { + part <- as.list(unlist(part) * split) + }) + + cache(rdd2) + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + persist(rdd2, "MEMORY_AND_DISK") + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + setCheckpointDir(sc, "checkpoints") + checkpoint(rdd2) + expect_true(rdd2@env$isCheckpointed) + + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + expect_false(rdd2@env$isCheckpointed) + + # make sure the data is collectable + collect(rdd2) + + unlink("checkpoints") +}) + +test_that("reduce on RDD", { + sum <- reduce(rdd, "+") + expect_equal(sum, 55) + + # Also test with an inline function + sumInline <- reduce(rdd, function(x, y) { x + y }) + expect_equal(sumInline, 55) +}) + +test_that("lapply with dependency", { + fa <- 5 + multiples <- lapply(rdd, function(x) { fa * x }) + actual <- collect(multiples) + + expect_equal(actual, as.list(nums * 5)) +}) + +test_that("lapplyPartitionsWithIndex on RDDs", { + func <- function(splitIndex, part) { list(splitIndex, Reduce("+", part)) } + actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) + expect_equal(actual, list(list(0, 15), list(1, 40))) + + pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) + partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } + mkTup <- function(splitIndex, part) { list(splitIndex, part) } + actual <- collect(lapplyPartitionsWithIndex( + partitionBy(pairsRDD, 2L, partitionByParity), + mkTup), + FALSE) + expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))), + list(1, list(list(4, 8))))) +}) + +test_that("sampleRDD() on RDDs", { + expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) +}) + +test_that("takeSample() on RDDs", { + # ported from RDDSuite.scala, modified seeds + data <- parallelize(sc, 1:100, 2L) + for (seed in 4:5) { + s <- takeSample(data, FALSE, 20L, seed) + expect_equal(length(s), 20L) + expect_equal(length(unique(s)), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, FALSE, 200L, seed) + expect_equal(length(s), 100L) + expect_equal(length(unique(s)), 100L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 20L, seed) + expect_equal(length(s), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 100L, seed) + expect_equal(length(s), 100L) + # Chance of getting all distinct elements is astronomically low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 200L, seed) + expect_equal(length(s), 200L) + # Chance of getting all distinct elements is still quite low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } +}) + +test_that("mapValues() on pairwise RDDs", { + multiples <- mapValues(intRdd, function(x) { x * 2 }) + actual <- collect(multiples) + expected <- lapply(intPairs, function(x) { + list(x[[1]], x[[2]] * 2) + }) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("flatMapValues() on pairwise RDDs", { + l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) + actual <- collect(flatMapValues(l, function(x) { x })) + expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) + + # Generate x to x+1 for every value + actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + expect_equal(actual, + list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), + list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) +}) + +test_that("reduceByKeyLocally() on PairwiseRDDs", { + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, 6), list(1.1, 3)))) + + pairs <- parallelize(sc, list(list("abc", 1.2), list(1.1, 0), list("abc", 1.3), + list("bb", 5)), 4L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("abc", 2.5), list(1.1, 0), list("bb", 5)))) +}) + +test_that("distinct() on RDDs", { + nums.rep2 <- rep(1:10, 2) + rdd.rep2 <- parallelize(sc, nums.rep2, 2L) + uniques <- distinct(rdd.rep2) + actual <- sort(unlist(collect(uniques))) + expect_equal(actual, nums) +}) + +test_that("maximum() on RDDs", { + max <- maximum(rdd) + expect_equal(max, 10) +}) + +test_that("minimum() on RDDs", { + min <- minimum(rdd) + expect_equal(min, 1) +}) + +test_that("sumRDD() on RDDs", { + sum <- sumRDD(rdd) + expect_equal(sum, 55) +}) + +test_that("keyBy on RDDs", { + func <- function(x) { x*x } + keys <- keyBy(rdd, func) + actual <- collect(keys) + expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) +}) + +test_that("repartition/coalesce on RDDs", { + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements + + # repartition + r1 <- repartition(rdd, 2) + expect_equal(numPartitions(r1), 2L) + count <- length(collectPartition(r1, 0L)) + expect_true(count >= 8 && count <= 12) + + r2 <- repartition(rdd, 6) + expect_equal(numPartitions(r2), 6L) + count <- length(collectPartition(r2, 0L)) + expect_true(count >=0 && count <= 4) + + # coalesce + r3 <- coalesce(rdd, 1) + expect_equal(numPartitions(r3), 1L) + count <- length(collectPartition(r3, 0L)) + expect_equal(count, 20) +}) + +test_that("sortBy() on RDDs", { + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) + actual <- collect(sortedRdd) + expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + sortedRdd2 <- sortBy(rdd2, function(x) { x * x }) + actual <- collect(sortedRdd2) + expect_equal(actual, as.list(nums)) +}) + +test_that("takeOrdered() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l)))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l)))[1:3]) +}) + +test_that("top() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- top(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- top(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:3]) +}) + +test_that("fold() on RDDs", { + actual <- fold(rdd, 0, "+") + expect_equal(actual, Reduce("+", nums, 0)) + + rdd <- parallelize(sc, list()) + actual <- fold(rdd, 0, "+") + expect_equal(actual, 0) +}) + +test_that("aggregateRDD() on RDDs", { + rdd <- parallelize(sc, list(1, 2, 3, 4)) + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(10, 4)) + + rdd <- parallelize(sc, list()) + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(0, 0)) +}) + +test_that("zipWithUniqueId() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 3), list("c", 1), + list("d", 4), list("e", 2)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("zipWithIndex() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("glom() on RDD", { + rdd <- parallelize(sc, as.list(1:4), 2L) + actual <- collect(glom(rdd)) + expect_equal(actual, list(list(1, 2), list(3, 4))) +}) + +test_that("keys() on RDDs", { + keys <- keys(intRdd) + actual <- collect(keys) + expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) +}) + +test_that("values() on RDDs", { + values <- values(intRdd) + actual <- collect(values) + expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) +}) + +test_that("pipeRDD() on RDDs", { + actual <- collect(pipeRDD(rdd, "more")) + expected <- as.list(as.character(1:10)) + expect_equal(actual, expected) + + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) + actual <- collect(pipeRDD(trailed.rdd, "sort")) + expected <- list("", "1", "2", "3") + expect_equal(actual, expected) + + rev.nums <- 9:0 + rev.rdd <- parallelize(sc, rev.nums, 2L) + actual <- collect(pipeRDD(rev.rdd, "sort")) + expected <- as.list(as.character(c(5:9, 0:4))) + expect_equal(actual, expected) +}) + +test_that("zipRDD() on RDDs", { + rdd1 <- parallelize(sc, 0:4, 2) + rdd2 <- parallelize(sc, 1000:1004, 2) + actual <- collect(zipRDD(rdd1, rdd2)) + expect_equal(actual, + list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipRDD(rdd, rdd)) + expected <- lapply(mockFile, function(x) { list(x ,x) }) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipRDD(rdd1, rdd)) + expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) }) + expect_equal(actual, expected) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(zipRDD(rdd, rdd1)) + expected <- lapply(mockFile, function(x) { list(x, x) }) + expect_equal(actual, expected) + + unlink(fileName) +}) + +test_that("join() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) +}) + +test_that("leftOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) + +test_that("rightOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("fullOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("sortByKey() on pairwise RDDs", { + numPairsRdd <- map(rdd, function(x) { list (x, x) }) + sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) + actual <- collect(sortedRdd) + numPairs <- lapply(nums, function(x) { list (x, x) }) + expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + sortedRdd2 <- sortByKey(numPairsRdd2) + actual <- collect(sortedRdd2) + expect_equal(actual, numPairs) + + # sort by string keys + l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5)) + rdd3 <- parallelize(sc, l, 2L) + sortedRdd3 <- sortByKey(rdd3) + actual <- collect(sortedRdd3) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # test on the boundary cases + + # boundary case 1: the RDD to be sorted has only 1 partition + rdd4 <- parallelize(sc, l, 1L) + sortedRdd4 <- sortByKey(rdd4) + actual <- collect(sortedRdd4) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 2: the sorted RDD has only 1 partition + rdd5 <- parallelize(sc, l, 2L) + sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L) + actual <- collect(sortedRdd5) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 3: the RDD to be sorted has only 1 element + l2 <- list(list("a", 1)) + rdd6 <- parallelize(sc, l2, 2L) + sortedRdd6 <- sortByKey(rdd6) + actual <- collect(sortedRdd6) + expect_equal(actual, l2) + + # boundary case 4: the RDD to be sorted has 0 element + l3 <- list() + rdd7 <- parallelize(sc, l3, 2L) + sortedRdd7 <- sortByKey(rdd7) + actual <- collect(sortedRdd7) + expect_equal(actual, l3) +}) + +test_that("collectAsMap() on a pairwise RDD", { + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = 2, `3` = 4)) + + rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(a = 1, b = 2)) + + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) + + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = "a", `2` = "b")) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R new file mode 100644 index 0000000000000..d1da8232aea81 --- /dev/null +++ b/R/pkg/inst/tests/test_shuffle.R @@ -0,0 +1,209 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("partitionBy, groupByKey, reduceByKey etc.") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +doublePairs <- list(list(1.5, -1), list(2.5, 100), list(2.5, 1), list(1.5, 200)) +doubleRdd <- parallelize(sc, doublePairs, 2L) + +numPairs <- list(list(1L, 100), list(2L, 200), list(4L, -1), list(3L, 1), + list(3L, 0)) +numPairsRdd <- parallelize(sc, numPairs, length(numPairs)) + +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ") +strListRDD <- parallelize(sc, strList, 4) + +test_that("groupByKey for integers", { + grouped <- groupByKey(intRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("groupByKey for doubles", { + grouped <- groupByKey(doubleRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for ints", { + reduced <- reduceByKey(intRdd, "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for doubles", { + reduced <- reduceByKey(doubleRdd, "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for ints", { + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for doubles", { + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("aggregateByKey", { + # test aggregateByKey for int keys + rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test aggregateByKey for string keys + rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("foldByKey", { + # test foldByKey for int keys + folded <- foldByKey(intRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for double keys + folded <- foldByKey(doubleRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for string keys + stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) + + stringKeyRDD <- parallelize(sc, stringKeyPairs) + folded <- foldByKey(stringKeyRDD, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list("b", 101), list("a", 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for empty pair RDD + rdd <- parallelize(sc, list()) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list() + expect_equal(actual, expected) + + # test foldByKey for RDD with only 1 pair + rdd <- parallelize(sc, list(list(1, 1))) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list(list(1, 1)) + expect_equal(actual, expected) +}) + +test_that("partitionBy() partitions data correctly", { + # Partition by magnitude + partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } + + resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) + + expected_first <- list(list(1, 100), list(2, 200)) # key < 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("partitionBy works with dependencies", { + kOne <- 1 + partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } + + # Partition by parity + resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity) + + # keys even; 100 %% 2 == 0 + expected_first <- list(list(2, 200), list(4, -1)) + # keys odd; 3 %% 2 == 1 + expected_second <- list(list(1, 100), list(3, 1), list(3, 0)) + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("test partitionBy with string keys", { + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + resultRDD <- partitionBy(wordCount, 2L) + expected_first <- list(list("Dexter", 1), list("Dexter", 1)) + expected_second <- list(list("and", 1), list("and", 1)) + + actual_first <- Filter(function(item) { item[[1]] == "Dexter" }, + collectPartition(resultRDD, 0L)) + actual_second <- Filter(function(item) { item[[1]] == "and" }, + collectPartition(resultRDD, 1L)) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R new file mode 100644 index 0000000000000..cf5cf6d1692af --- /dev/null +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -0,0 +1,695 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("SparkSQL functions") + +# Tests for SparkSQL functions in SparkR + +sc <- sparkR.init() + +sqlCtx <- sparkRSQL.init(sc) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") +writeLines(mockLines, jsonPath) + +test_that("infer types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(a = 1L, b = "2")), + list(type = "struct", + fields = list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)))) + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), + list(type = "map", keyType = "string", valueType = "integer", + valueContainsNull = TRUE)) +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- createDataFrame(sqlCtx, rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- toDF(rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlCtx, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a=1, b=2), list(a=3, b=4)) + df <- createDataFrame(sqlCtx, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlCtx, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlCtx, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +# TODO: enable this test after fix serialization for nested object +#test_that("create DataFrame with nested array and struct", { +# e <- new.env() +# assign("n", 3L, envir = e) +# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) +# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), +# c("c", "map"), c("d", "struct"))) +# expect_equal(count(df), 1) +# ldf <- collect(df) +# expect_equal(ldf[1,], l[[1]]) +#}) + +test_that("jsonFile() on a local file returns a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("jsonRDD() on a RDD with json string", { + rdd <- parallelize(sc, mockLines) + expect_true(count(rdd) == 3) + df <- jsonRDD(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + + rdd2 <- flatMap(rdd, function(x) c(x, x)) + df <- jsonRDD(sqlCtx, rdd2) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 6) +}) + +test_that("test cache, uncache and clearCache", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + cacheTable(sqlCtx, "table1") + uncacheTable(sqlCtx, "table1") + clearCache(sqlCtx) + dropTempTable(sqlCtx, "table1") +}) + +test_that("test tableNames and tables", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + expect_true(length(tableNames(sqlCtx)) == 1) + df <- tables(sqlCtx) + expect_true(count(df) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") + expect_true(inherits(newdf, "DataFrame")) + expect_true(count(newdf) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("insertInto() on a registered table", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", "overwrite") + dfParquet <- loadDF(sqlCtx, parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + saveDF(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_true(count(sql(sqlCtx, "select * from table1")) == 5) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") + dropTempTable(sqlCtx, "table1") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_true(count(sql(sqlCtx, "select * from table1")) == 2) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") + dropTempTable(sqlCtx, "table1") +}) + +test_that("table() returns a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + tabledf <- table(sqlCtx, "table1") + expect_true(inherits(tabledf, "DataFrame")) + expect_true(count(tabledf) == 3) + dropTempTable(sqlCtx, "table1") +}) + +test_that("toRDD() returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toRDD(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(count(testRDD) == 3) +}) + +test_that("union on two RDDs created from DataFrames returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + RDD1 <- toRDD(df) + RDD2 <- toRDD(df) + unioned <- unionRDD(RDD1, RDD2) + expect_true(inherits(unioned, "RDD")) + expect_true(SparkR:::getSerializedMode(unioned) == "byte") + expect_true(collect(unioned)[[2]]$name == "Andy") +}) + +test_that("union on mixed serialization types correctly returns a byte RRDD", { + # Byte RDD + nums <- 1:10 + rdd <- parallelize(sc, nums, 2L) + + # String RDD + textLines <- c("Michael", + "Andy, 30", + "Justin, 19") + textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") + writeLines(textLines, textPath) + textRDD <- textFile(sc, textPath) + + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + + unionByte <- unionRDD(rdd, dfRDD) + expect_true(inherits(unionByte, "RDD")) + expect_true(SparkR:::getSerializedMode(unionByte) == "byte") + expect_true(collect(unionByte)[[1]] == 1) + expect_true(collect(unionByte)[[12]]$name == "Andy") + + unionString <- unionRDD(textRDD, dfRDD) + expect_true(inherits(unionString, "RDD")) + expect_true(SparkR:::getSerializedMode(unionString) == "byte") + expect_true(collect(unionString)[[1]] == "Michael") + expect_true(collect(unionString)[[5]]$name == "Andy") +}) + +test_that("objectFile() works with row serialization", { + objectPath <- tempfile(pattern="spark-test", fileext=".tmp") + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + objectIn <- objectFile(sc, objectPath) + + expect_true(inherits(objectIn, "RDD")) + expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(collect(objectIn)[[2]]$age, 30) +}) + +test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- lapply(df, function(row) { + row$newCol <- row$age + 5 + row + }) + expect_true(inherits(testRDD, "RDD")) + collected <- collect(testRDD) + expect_true(collected[[1]]$name == "Michael") + expect_true(collected[[2]]$newCol == "35") +}) + +test_that("collect() returns a data.frame", { + df <- jsonFile(sqlCtx, jsonPath) + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_true(names(rdf)[1] == "age") + expect_true(nrow(rdf) == 3) + expect_true(ncol(rdf) == 2) +}) + +test_that("limit() returns DataFrame with the correct number of rows", { + df <- jsonFile(sqlCtx, jsonPath) + dfLimited <- limit(df, 2) + expect_true(inherits(dfLimited, "DataFrame")) + expect_true(count(dfLimited) == 2) +}) + +test_that("collect() and take() on a DataFrame return the same number of rows and columns", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(nrow(collect(df)) == nrow(take(df, 10))) + expect_true(ncol(collect(df)) == ncol(take(df, 10))) +}) + +test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + first <- lapply(df, function(row) { + row$age <- row$age + 5 + row + }) + second <- lapply(first, function(row) { + row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE + row + }) + expect_true(inherits(second, "RDD")) + expect_true(count(second) == 3) + expect_true(collect(second)[[2]]$age == 35) + expect_true(collect(second)[[2]]$testCol) + expect_false(collect(second)[[3]]$testCol) +}) + +test_that("cache(), persist(), and unpersist() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(df@env$isCached) + cache(df) + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + persist(df, "MEMORY_AND_DISK") + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + # make sure the data is collectable + expect_true(is.data.frame(collect(df))) +}) + +test_that("schema(), dtypes(), columns(), names() return the correct values/format", { + df <- jsonFile(sqlCtx, jsonPath) + testSchema <- schema(df) + expect_true(length(testSchema$fields()) == 2) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") + expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") + expect_true(testSchema$fields()[[1]]$name() == "age") + + testTypes <- dtypes(df) + expect_true(length(testTypes[[1]]) == 2) + expect_true(testTypes[[1]][1] == "age") + + testCols <- columns(df) + expect_true(length(testCols) == 2) + expect_true(testCols[2] == "name") + + testNames <- names(df) + expect_true(length(testNames) == 2) + expect_true(testNames[2] == "name") +}) + +test_that("head() and first() return the correct data", { + df <- jsonFile(sqlCtx, jsonPath) + testHead <- head(df) + expect_true(nrow(testHead) == 3) + expect_true(ncol(testHead) == 2) + + testHead2 <- head(df, 2) + expect_true(nrow(testHead2) == 2) + expect_true(ncol(testHead2) == 2) + + testFirst <- first(df) + expect_true(nrow(testFirst) == 1) +}) + +test_that("distinct() on DataFrames", { + lines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":19}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + + df <- jsonFile(sqlCtx, jsonPathWithDup) + uniques <- distinct(df) + expect_true(inherits(uniques, "DataFrame")) + expect_true(count(uniques) == 3) +}) + +test_that("sampleDF on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sampled <- sampleDF(df, FALSE, 1.0) + expect_equal(nrow(collect(sampled)), count(df)) + expect_true(inherits(sampled, "DataFrame")) + sampled2 <- sampleDF(df, FALSE, 0.1) + expect_true(count(sampled2) < 3) +}) + +test_that("select operators", { + df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") + expect_true(inherits(df$name, "Column")) + expect_true(inherits(df[[2]], "Column")) + expect_true(inherits(df[["age"]], "Column")) + + expect_true(inherits(df[,1], "DataFrame")) + expect_equal(columns(df[,1]), c("name")) + expect_equal(columns(df[,"age"]), c("age")) + df2 <- df[,c("age", "name")] + expect_true(inherits(df2, "DataFrame")) + expect_equal(columns(df2), c("age", "name")) + + df$age2 <- df$age + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age)), 2) + df$age2 <- df$age * 2 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 2)), 2) +}) + +test_that("select with column", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- select(df, "name") + expect_true(columns(df1) == c("name")) + expect_true(count(df1) == 3) + + df2 <- select(df, df$age) + expect_true(columns(df2) == c("age")) + expect_true(count(df2) == 3) +}) + +test_that("selectExpr() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_true(names(selected) == "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_true(count(selected2) == 3) +}) + +test_that("column calculation", { + df <- jsonFile(sqlCtx, jsonPath) + d <- collect(select(df, alias(df$age + 1, "age2"))) + expect_true(names(d) == c("age2")) + df2 <- select(df, lower(df$name), abs(df$age)) + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("load() from json file", { + df <- loadDF(sqlCtx, jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("save() as parquet file", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", mode="overwrite") + df2 <- loadDF(sqlCtx, parquetPath, "parquet") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("test HiveContext", { + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + df2 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) + + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + saveAsTable(df, "json", "json", "append", path = jsonPath2) + df3 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df3, "DataFrame")) + expect_true(count(df3) == 6) +}) + +test_that("column operators", { + c <- SparkR:::col("a") + c2 <- (- c + 1 - 2) * 3 / 4.0 + c3 <- (c + c2 - c2) * c2 %% c2 + c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) +}) + +test_that("column functions", { + c <- SparkR:::col("a") + c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) + c3 <- lower(c) + upper(c) + first(c) + last(c) + c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") +}) + +test_that("string operators", { + df <- jsonFile(sqlCtx, jsonPath) + expect_equal(count(where(df, like(df$name, "A%"))), 1) + expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") +}) + +test_that("group by", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_true(1 == count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_true(1 == count(df1)) + expect_equal(columns(df1), c("age2")) + + gd <- groupBy(df, "name") + expect_true(inherits(gd, "GroupedData")) + df2 <- count(gd) + expect_true(inherits(df2, "DataFrame")) + expect_true(3 == count(df2)) + + df3 <- agg(gd, age = "sum") + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + + df3 <- agg(gd, age = sum(df$age)) + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + expect_equal(columns(df3), c("name", "age")) + + df4 <- sum(gd, "age") + expect_true(inherits(df4, "DataFrame")) + expect_true(3 == count(df4)) + expect_true(3 == count(mean(gd, "age"))) + expect_true(3 == count(max(gd, "age"))) +}) + +test_that("sortDF() and orderBy() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sorted <- sortDF(df, df$age) + expect_true(collect(sorted)[1,2] == "Michael") + + sorted2 <- sortDF(df, "name") + expect_true(collect(sorted2)[2,"age"] == 19) + + sorted3 <- orderBy(df, asc(df$age)) + expect_true(is.na(first(sorted3)$age)) + expect_true(collect(sorted3)[2, "age"] == 19) + + sorted4 <- orderBy(df, desc(df$name)) + expect_true(first(sorted4)$name == "Michael") + expect_true(collect(sorted4)[3,"name"] == "Andy") +}) + +test_that("filter() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + filtered <- filter(df, "age > 20") + expect_true(count(filtered) == 1) + expect_true(collect(filtered)$name == "Andy") + filtered2 <- where(df, df$name != "Michael") + expect_true(count(filtered2) == 2) + expect_true(collect(filtered2)$age[2] == 19) +}) + +test_that("join() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"test\": \"yes\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + df2 <- jsonFile(sqlCtx, jsonPath2) + + joined <- join(df, df2) + expect_equal(names(joined), c("age", "name", "name", "test")) + expect_true(count(joined) == 12) + + joined2 <- join(df, df2, df$name == df2$name) + expect_equal(names(joined2), c("age", "name", "name", "test")) + expect_true(count(joined2) == 3) + + joined3 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined3), c("age", "name", "name", "test")) + expect_true(count(joined3) == 4) + expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) + + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(joined4), c("newAge", "name", "test")) + expect_true(count(joined4) == 4) + expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) +}) + +test_that("toJSON() returns an RDD of the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toJSON(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_equal(collect(testRDD)[[1]], mockLines[1]) +}) + +test_that("showDF()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_output(showDF(df), "age name \nnull Michael\n30 Andy \n19 Justin ") +}) + +test_that("isLocal()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(isLocal(df)) +}) + +test_that("unionAll(), subtract(), and intersect() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + + unioned <- sortDF(unionAll(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(unioned) == 6) + expect_true(first(unioned)$name == "Michael") + + subtracted <- sortDF(subtract(df, df2), desc(df$age)) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(subtracted) == 2) + expect_true(first(subtracted)$name == "Justin") + + intersected <- sortDF(intersect(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(intersected) == 1) + expect_true(first(intersected)$name == "Andy") +}) + +test_that("withColumn() and withColumnRenamed()", { + df <- jsonFile(sqlCtx, jsonPath) + newDF <- withColumn(df, "newAge", df$age + 2) + expect_true(length(columns(newDF)) == 3) + expect_true(columns(newDF)[3] == "newAge") + expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + + newDF2 <- withColumnRenamed(df, "age", "newerAge") + expect_true(length(columns(newDF2)) == 2) + expect_true(columns(newDF2)[1] == "newerAge") +}) + +test_that("saveDF() on DataFrame and works with parquetFile", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath) + expect_true(inherits(parquetDF, "DataFrame")) + expect_equal(count(df), count(parquetDF)) +}) + +test_that("parquetFile works with multiple input paths", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + saveDF(df, parquetPath2, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) + expect_true(inherits(parquetDF, "DataFrame")) + expect_true(count(parquetDF) == count(df)*2) +}) + +unlink(parquetPath) +unlink(jsonPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R new file mode 100644 index 0000000000000..7f4c7c315d787 --- /dev/null +++ b/R/pkg/inst/tests/test_take.R @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("tests RDD function take()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +# JavaSparkContext handle +jsc <- sparkR.init() + +test_that("take() gives back the original elements in correct count and order", { + numVectorRDD <- parallelize(jsc, numVector, 10) + # case: number of elements to take is less than the size of the first partition + expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) + # case: number of elements to take is the same as the size of the first partition + expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11))) + # case: number of elements to take is greater than all elements + expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) + expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) + expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) + expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) + expect_equal(take(numListRDD2, 999), numList) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(take(strVectorRDD, 4), as.list(strVector)) + expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) + expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) + + expect_true(length(take(strListRDD, 0)) == 0) + expect_true(length(take(strVectorRDD, 0)) == 0) + expect_true(length(take(numListRDD, 0)) == 0) + expect_true(length(take(numVectorRDD, 0)) == 0) +}) + diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R new file mode 100644 index 0000000000000..7bb3e8003131d --- /dev/null +++ b/R/pkg/inst/tests/test_textFile.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("the textFile() function") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("textFile() on a local file returns an RDD", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_true(inherits(rdd, "RDD")) + expect_true(count(rdd) > 0) + expect_true(count(rdd) == 2) + + unlink(fileName) +}) + +test_that("textFile() followed by a collect() returns the same content", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName) +}) + +test_that("textFile() word count works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + output <- collect(counts) + expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), + list("Spark", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName) +}) + +test_that("several transformations on RDD created by textFile()", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) # RDD + for (i in 1:10) { + # PipelinedRDD initially created from RDD + rdd <- lapply(rdd, function(x) paste(x, x)) + } + collect(rdd) + + unlink(fileName) +}) + +test_that("textFile() followed by a saveAsTextFile() returns the same content", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + saveAsTextFile(rdd, fileName2) + rdd <- textFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("saveAsTextFile() on a parallelized list works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + l <- list(1, 2, 3) + rdd <- parallelize(sc, l) + saveAsTextFile(rdd, fileName) + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), lapply(l, function(x) {toString(x)})) + + unlink(fileName) +}) + +test_that("textFile() and saveAsTextFile() word count works as expected", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsTextFile(counts, fileName2) + rdd <- textFile(sc, fileName2) + + output <- collect(rdd) + expected <- list(list("awesome.", 1), list("Spark", 2), + list("pretty.", 1), list("is", 2)) + expectedStr <- lapply(expected, function(x) { toString(x) }) + expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("textFile() on multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines("Spark is pretty.", fileName1) + writeLines("Spark is awesome.", fileName2) + + rdd <- textFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("Pipelined operations on RDDs created using textFile", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + lengths <- lapply(rdd, function(x) { length(x) }) + expect_equal(collect(lengths), list(1, 1)) + + lengthsPipelined <- lapply(lengths, function(x) { x + 10 }) + expect_equal(collect(lengthsPipelined), list(11, 11)) + + lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 }) + expect_equal(collect(lengths30), list(31, 31)) + + lengths20 <- lapply(lengths, function(x) { x + 20 }) + expect_equal(collect(lengths20), list(21, 21)) + + unlink(fileName) +}) + diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R new file mode 100644 index 0000000000000..9c5bb427932b4 --- /dev/null +++ b/R/pkg/inst/tests/test_utils.R @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in utils.R") + +# JavaSparkContext handle +sc <- sparkR.init() + +test_that("convertJListToRList() gives back (deserializes) the original JLists + of strings and integers", { + # It's hard to manually create a Java List using rJava, since it does not + # support generics well. Instead, we rely on collect() returning a + # JList. + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 1L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, nums) + + strs <- as.list("hello", "spark") + rdd <- parallelize(sc, strs, 2L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, strs) +}) + +test_that("serializeToBytes on RDD", { + # File content + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + expect_true(getSerializedMode(text.rdd) == "string") + ser.rdd <- serializeToBytes(text.rdd) + expect_equal(collect(ser.rdd), as.list(mockFile)) + expect_true(getSerializedMode(ser.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cleanClosure on R functions", { + y <- c(1, 2, 3) + g <- function(x) { x + 1 } + f <- function(x) { g(x) + y } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # y, g + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + # Test for nested enclosures and package variables. + env2 <- new.env() + funcEnv <- new.env(parent = env2) + f <- function(x) { log(g(x) + y) } + environment(f) <- funcEnv # enclosing relationship: f -> funcEnv -> env2 -> .GlobalEnv + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # "min" should not be included + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + base <- c(1, 2, 3) + l <- list(field = matrix(1)) + field <- matrix(2) + defUse <- 3 + g <- function(x) { x + y } + f <- function(x) { + defUse <- base::as.integer(x) + 1 # Test for access operators `::`. + lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply. + l$field[1,1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + f(res) # Test for recursive calls. + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + expect_true("g" %in% ls(env)) + expect_true("l" %in% ls(env)) + expect_true("f" %in% ls(env)) + expect_equal(get("l", envir = env, inherits = FALSE), l) + # "y" should be in the environemnt of g. + newG <- get("g", envir = env, inherits = FALSE) + env <- environment(newG) + expect_equal(length(ls(env)), 1) + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + + # Test for function (and variable) definitions. + f <- function(x) { + g <- function(y) { y * 2 } + g(x) + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. + + # Test for overriding variables in base namespace (Issue: SparkR-196). + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 2L) + t = 4 # Override base::t in .GlobalEnv. + f <- function(x) { x > t } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(ls(env), "t") + expect_equal(get("t", envir = env, inherits = FALSE), t) + actual <- collect(lapply(rdd, f)) + expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) + expect_equal(actual, expected) + + # Test for broadcast variables. + a <- matrix(nrow=10, ncol=10, data=rnorm(100)) + aBroadcast <- broadcast(sc, a) + normMultiply <- function(x) { norm(aBroadcast$value) * x } + newnormMultiply <- SparkR:::cleanClosure(normMultiply) + env <- environment(newnormMultiply) + expect_equal(ls(env), "aBroadcast") + expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast) +}) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R new file mode 100644 index 0000000000000..3584b418a71a9 --- /dev/null +++ b/R/pkg/inst/worker/daemon.R @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker daemon + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") + +# preload SparkR package, speedup worker +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) + +while (TRUE) { + ready <- socketSelect(list(inputCon)) + if (ready) { + port <- SparkR:::readInt(inputCon) + # There is a small chance that it could be interrupted by signal, retry one time + if (length(port) == 0) { + port <- SparkR:::readInt(inputCon) + if (length(port) == 0) { + cat("quitting daemon\n") + quit(save = "no") + } + } + p <- parallel:::mcfork() + if (inherits(p, "masterProcess")) { + close(inputCon) + Sys.setenv(SPARKR_WORKER_PORT = port) + source(script) + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) + } + } +} diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R new file mode 100644 index 0000000000000..c6542928e8ddd --- /dev/null +++ b/R/pkg/inst/worker/worker.R @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker class + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +# Set libPaths to include SparkR package as loadNamespace needs this +# TODO: Figure out if we can avoid this by not loading any objects that require +# SparkR namespace +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") +outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") + +# read the index of the current partition inside the RDD +partition <- SparkR:::readInt(inputCon) + +deserializer <- SparkR:::readString(inputCon) +serializer <- SparkR:::readString(inputCon) + +# Include packages as required +packageNames <- unserialize(SparkR:::readRaw(inputCon)) +for (pkg in packageNames) { + suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE)) +} + +# read function dependencies +funcLen <- SparkR:::readInt(inputCon) +computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) +env <- environment(computeFunc) +parent.env(env) <- .GlobalEnv # Attach under global environment. + +# Read and set broadcast variables +numBroadcastVars <- SparkR:::readInt(inputCon) +if (numBroadcastVars > 0) { + for (bcast in seq(1:numBroadcastVars)) { + bcastId <- SparkR:::readInt(inputCon) + value <- unserialize(SparkR:::readRaw(inputCon)) + setBroadcastValue(bcastId, value) + } +} + +# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int +# as number of partitions to create. +numPartitions <- SparkR:::readInt(inputCon) + +isEmpty <- SparkR:::readInt(inputCon) + +if (isEmpty != 0) { + + if (numPartitions == -1) { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- as.list(readLines(inputCon)) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + output <- computeFunc(partition, data) + if (serializer == "byte") { + SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) + } else { + SparkR:::writeStrings(outputCon, output) + } + } else { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- readLines(inputCon) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + + res <- new.env() + + # Step 1: hash the data to an environment + hashTupleToEnvir <- function(tuple) { + # NOTE: execFunction is the hash function here + hashVal <- computeFunc(tuple[[1]]) + bucket <- as.character(hashVal %% numPartitions) + acc <- res[[bucket]] + # Create a new accumulator + if (is.null(acc)) { + acc <- SparkR:::initAccumulator() + } + SparkR:::addItemToAccumulator(acc, tuple) + res[[bucket]] <- acc + } + invisible(lapply(data, hashTupleToEnvir)) + + # Step 2: write out all of the environment as key-value pairs. + for (name in ls(res)) { + SparkR:::writeInt(outputCon, 2L) + SparkR:::writeInt(outputCon, as.integer(name)) + # Truncate the accumulator list to the number of elements we have + length(res[[name]]$data) <- res[[name]]$counter + SparkR:::writeRawSerialize(outputCon, res[[name]]$data) + } + } +} + +# End of output +if (serializer %in% c("byte", "row")) { + SparkR:::writeInt(outputCon, 0L) +} + +close(outputCon) +close(inputCon) diff --git a/R/pkg/src/Makefile b/R/pkg/src/Makefile new file mode 100644 index 0000000000000..a55a56fe80e10 --- /dev/null +++ b/R/pkg/src/Makefile @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.so string_hash_code.c + +clean: + rm -f *.o + rm -f *.so + +.PHONY: all clean diff --git a/R/pkg/src/Makefile.win b/R/pkg/src/Makefile.win new file mode 100644 index 0000000000000..aa486d8228371 --- /dev/null +++ b/R/pkg/src/Makefile.win @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.dll string_hash_code.c + +clean: + rm -f *.o + rm -f *.dll + +.PHONY: all clean diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src/string_hash_code.c new file mode 100644 index 0000000000000..e3274b9a0c547 --- /dev/null +++ b/R/pkg/src/string_hash_code.c @@ -0,0 +1,49 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/* + * A C function for R extension which implements the Java String hash algorithm. + * Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function + * + */ + +#include +#include + +/* for compatibility with R before 3.1 */ +#ifndef IS_SCALAR +#define IS_SCALAR(x, type) (TYPEOF(x) == (type) && XLENGTH(x) == 1) +#endif + +SEXP stringHashCode(SEXP string) { + const char* str; + R_xlen_t len, i; + int hashCode = 0; + + if (!IS_SCALAR(string, STRSXP)) { + error("invalid input"); + } + + str = CHAR(asChar(string)); + len = XLENGTH(asChar(string)); + + for (i = 0; i < len; i++) { + hashCode = (hashCode << 5) - hashCode + *str++; + } + + return ScalarInteger(hashCode); +} diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R new file mode 100644 index 0000000000000..4f8a1ed2d83ef --- /dev/null +++ b/R/pkg/tests/run-all.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) +library(SparkR) + +test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh new file mode 100755 index 0000000000000..e82ad0ba2cd06 --- /dev/null +++ b/R/run-tests.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FWDIR="$(cd `dirname $0`; pwd)" + +FAILED=0 +LOGFILE=$FWDIR/unit-tests.out +rm -f $LOGFILE + +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +FAILED=$((PIPESTATUS[0]||$FAILED)) + +if [[ $FAILED != 0 ]]; then + cat $LOGFILE + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 +else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color +fi diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 853ef0ed2986f..edbecdae92096 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 2d7070c25d328..95779e9ddbb18 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -20,6 +20,7 @@ # This script loads spark-env.sh if it exists, and ensures it is only loaded once. # spark-env.sh is loaded from SPARK_CONF_DIR if set, or within the current directory's # conf/ subdirectory. +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 @@ -41,8 +42,8 @@ fi if [ -z "$SPARK_SCALA_VERSION" ]; then - ASSEMBLY_DIR2="$SPARK_HOME/assembly/target/scala-2.11" - ASSEMBLY_DIR1="$SPARK_HOME/assembly/target/scala-2.10" + ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" + ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 diff --git a/bin/sparkR b/bin/sparkR new file mode 100755 index 0000000000000..8c918e2b09aef --- /dev/null +++ b/bin/sparkR @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +source "$SPARK_HOME"/bin/load-spark-env.sh + +function usage() { + if [ -n "$1" ]; then + echo $1 + fi + echo "Usage: ./bin/sparkR [options]" 1>&2 + "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit $2 +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage +fi + +exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd new file mode 100644 index 0000000000000..d7b60183ca8e0 --- /dev/null +++ b/bin/sparkR.cmd @@ -0,0 +1,23 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This is the entry point for running SparkR. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0sparkR2.cmd %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd new file mode 100644 index 0000000000000..e47f22c7300bb --- /dev/null +++ b/bin/sparkR2.cmd @@ -0,0 +1,26 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Figure out where the Spark framework is installed +set SPARK_HOME=%~dp0.. + +call %SPARK_HOME%\bin\load-spark-env.cmd + + +call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %* diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 89eec7d4b7f61..3a2a88219818f 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -6,7 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/pom.xml b/core/pom.xml index 6cd1965ec37c2..e80829b7a7f3d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -442,4 +442,55 @@ + + + Windows + + + Windows + + + + \ + .bat + + + + unix + + + unix + + + + / + .sh + + + + sparkr + + + + org.codehaus.mojo + exec-maven-plugin + 1.3.2 + + + sparkr-pkg + compile + + exec + + + + + ..${path.separator}R${path.separator}install-dev${script.extension} + + + + + + + diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 89eec7d4b7f61..3a2a88219818f 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -6,7 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala new file mode 100644 index 0000000000000..3a2c94bd9d875 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetSocketAddress, ServerSocket} +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} + +import org.apache.spark.Logging + +/** + * Netty-based backend server that is used to communicate between R and Java. + */ +private[spark] class RBackend { + + private[this] var channelFuture: ChannelFuture = null + private[this] var bootstrap: ServerBootstrap = null + private[this] var bossGroup: EventLoopGroup = null + + def init(): Int = { + bossGroup = new NioEventLoopGroup(2) + val workerGroup = bossGroup + val handler = new RBackendHandler(this) + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast("frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", handler) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress(0)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } + +} + +private[spark] object RBackend extends Logging { + def main(args: Array[String]): Unit = { + if (args.length < 1) { + System.err.println("Usage: RBackend ") + System.exit(-1) + } + val sparkRBackend = new RBackend() + try { + // bind to random port + val boundPort = sparkRBackend.init() + val serverSocket = new ServerSocket(0, 1) + val listenPort = serverSocket.getLocalPort() + + // tell the R process via temporary file + val path = args(0) + val f = new File(path + ".tmp") + val dos = new DataOutputStream(new FileOutputStream(f)) + dos.writeInt(boundPort) + dos.writeInt(listenPort) + dos.close() + f.renameTo(new File(path)) + + // wait for the end of stdin, then exit + new Thread("wait for socket to close") { + setDaemon(true) + override def run(): Unit = { + // any un-catched exception will also shutdown JVM + val buf = new Array[Byte](1024) + // shutdown JVM if R does not connect back in 10 seconds + serverSocket.setSoTimeout(10000) + try { + val inSocket = serverSocket.accept() + serverSocket.close() + // wait for the end of socket, closed if R process die + inSocket.getInputStream().read(buf) + } finally { + sparkRBackend.close() + System.exit(0) + } + } + }.start() + + sparkRBackend.run() + } catch { + case e: IOException => + logError("Server shutting down: failed with exception ", e) + sparkRBackend.close() + System.exit(1) + } + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala new file mode 100644 index 0000000000000..0075d963711f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.mutable.HashMap + +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.api.r.SerDe._ + +/** + * Handler for RBackend + * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use + * this across connections ? + */ +@Sharable +private[r] class RBackendHandler(server: RBackend) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = readBoolean(dis) + val objId = readString(dis) + val methodName = readString(dis) + val numArgs = readInt(dis) + + if (objId == "SparkRHandler") { + methodName match { + case "stopBackend" => + writeInt(dos, 0) + writeType(dos, "void") + server.close() + case "rm" => + try { + val t = readObjectType(dis) + assert(t == 'c') + val objToRemove = readString(dis) + JVMObjectTracker.remove(objToRemove) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + writeInt(dos, -1) + } + case _ => dos.writeInt(-1) + } + } else { + handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + } + + val reply = bos.toByteArray + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Close the connection when an exception is raised. + cause.printStackTrace() + ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + try { + val cls = if (isStatic) { + Class.forName(objId) + } else { + JVMObjectTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + val args = readArgs(numArgs, dis) + + val methods = cls.getMethods + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val methods = selectedMethods.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + } + if (methods.isEmpty) { + logWarning(s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + val ret = methods.head.invoke(obj, args:_*) + + // Write status bit + writeInt(dos, 0) + writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args:_*) + + writeInt(dos, 0) + writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Exception => + logError(s"$methodName on $objId failed", e) + writeInt(dos, -1) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. + def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 to numArgs - 1) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + return false + } + } + true + } +} + +/** + * Helper singleton that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] object JVMObjectTracker { + + // TODO: This map should be thread-safe if we want to support multiple + // connections at the same time + private[this] val objMap = new HashMap[String, Object] + + // TODO: We support only one connection now, so an integer is fine. + // Investigate using use atomic integer in the future. + private[this] var objCounter: Int = 0 + + def getObject(id: String): Object = { + objMap(id) + } + + def get(id: String): Option[Object] = { + objMap.get(id) + } + + def put(obj: Object): String = { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + + def remove(id: String): Option[Object] = { + objMap.remove(id) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala new file mode 100644 index 0000000000000..5fa4d483b8342 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io._ +import java.net.ServerSocket +import java.util.{Map => JMap} + +import scala.collection.JavaConversions._ +import scala.io.Source +import scala.reflect.ClassTag +import scala.util.Try + +import org.apache.spark._ +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( + parent: RDD[T], + numPartitions: Int, + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Broadcast[Object]]) + extends RDD[U](parent) with Logging { + override def getPartitions: Array[Partition] = parent.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + + // The parent may be also an RRDD, so we should launch it first. + val parentIterator = firstParent[T].iterator(partition, context) + + // we expect two connections + val serverSocket = new ServerSocket(0, 2) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRDD.createRWorker(rLibDir, listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. + // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + val dataStream = openDataStream(inputStream) + serverSocket.close() + + try { + + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread[T]( + output: OutputStream, + iter: Iterator[T], + partition: Int): Unit = { + + val env = SparkEnv.get + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partition) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? + val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + printOut.println(elem) + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + protected def openDataStream(input: InputStream): Closeable + + protected def read(): U +} + +/** + * Form an RDD[(Int, Array[Byte])] from key-value pairs returned from R. + * This is used by SparkR's shuffle operations. + */ +private class PairwiseRRDD[T: ClassTag]( + parent: RDD[T], + numPartitions: Int, + hashFunc: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, (Int, Array[Byte])]( + parent, numPartitions, hashFunc, deserializer, + SerializationFormats.BYTE, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): (Int, Array[Byte]) = { + try { + val length = dataStream.readInt() + + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null // End of input + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) +} + +/** + * An RDD that stores serialized R objects as Array[Byte]. + */ +private class RRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, Array[Byte]]( + parent, -1, func, deserializer, serializer, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): Array[Byte] = { + try { + val length = dataStream.readInt() + + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj, 0, length) + obj + case _ => null + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) +} + +/** + * An RDD that stores R objects as Array[String]. + */ +private class StringRRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, String]( + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: BufferedReader = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new BufferedReader(new InputStreamReader(input)) + dataStream + } + + override protected def read(): String = { + try { + dataStream.readLine() + } catch { + case e: IOException => { + throw new SparkException("R worker exited unexpectedly (crashed)", e) + } + } + } + + lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + +private[r] object RRDD { + // Because forking processes from Java is expensive, we prefer to launch + // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + + def createSparkContext( + master: String, + appName: String, + sparkHome: String, + jars: Array[String], + sparkEnvirMap: JMap[Object, Object], + sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { + + val sparkConf = new SparkConf().setAppName(appName) + .setSparkHome(sparkHome) + .setJars(jars) + + // Override `master` if we have a user-specified value + if (master != "") { + sparkConf.setMaster(master) + } else { + // If conf has no master set it to "local" to maintain + // backwards compatibility + sparkConf.setIfMissing("spark.master", "local") + } + + for ((name, value) <- sparkEnvirMap) { + sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + } + for ((name, value) <- sparkExecutorEnvMap) { + sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + } + + new JavaSparkContext(sparkConf) + } + + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { + val rCommand = "Rscript" + val rOptions = "--vanilla" + val rExecScript = rLibDir + "/SparkR/worker/" + script + val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(rLibDir, port, "worker.R") + } + } + + /** + * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is + * called from R. + */ + def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala new file mode 100644 index 0000000000000..ccb2a371f4e48 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.sql.{Date, Time} + +import scala.collection.JavaConversions._ + +/** + * Utility functions to serialize, deserialize objects to / from R + */ +private[spark] object SerDe { + + // Type mapping from R to Java + // + // NULL -> void + // integer -> Int + // character -> String + // logical -> Boolean + // double, numeric -> Double + // raw -> Array[Byte] + // Date -> Date + // POSIXlt/POSIXct -> Time + // + // list[T] -> Array[T], where T is one of above mentioned types + // environment -> Map[String, T], where T is a native type + // jobj -> Object, where jobj is an object created in the backend + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + def readTypedObject( + dis: DataInputStream, + dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => JVMObjectTracker.getObject(readString(dis)) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + val bytesRead = in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + val asciiBytes = new Array[Byte](len) + in.readFully(asciiBytes) + assert(asciiBytes(len - 1) == 0) + val str = new String(asciiBytes.dropRight(1).map(_.toChar)) + str + } + + def readBoolean(in: DataInputStream): Boolean = { + val intVal = in.readInt() + if (intVal == 0) false else true + } + + def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + def readTime(in: DataInputStream): Time = { + val t = in.readDouble() + new Time((t * 1000L).toLong) + } + + def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'r' => readBytesArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesType = readObjectType(in) + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + mapAsJavaMap(keys.zip(values).toMap) + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Methods to write out data from Java to R + // + // Type mapping from Java to R + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Double -> double + // Long -> double + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "long" | "java.lang.Long" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "double" | "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "int" | "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "boolean" | "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + // TODO: Types not handled right now include + // byte, char, short, float + + // Handle arrays + case "[Ljava.lang.String;" => + writeType(dos, "list") + writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[I" => + writeType(dos, "list") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) + case "[D" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[Z" => + writeType(dos, "list") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + case "[[B" => + writeType(dos, "list") + writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) + case otherName => + // Handle array of objects + if (otherName.startsWith("[L")) { + val objArr = value.asInstanceOf[Array[Object]] + writeType(dos, "list") + writeType(dos, "jobj") + dos.writeInt(objArr.length) + objArr.foreach(o => writeJObj(dos, o)) + } else { + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeDouble(out: DataOutputStream, value: Double): Unit = { + out.writeDouble(value) + } + + def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + val intValue = if (value) 1 else 0 + out.writeInt(intValue) + } + + def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + + // NOTE: Only works for ASCII right now + def writeString(out: DataOutputStream, value: String): Unit = { + val len = value.length + out.writeInt(len + 1) // For the \0 + out.writeBytes(value) + out.writeByte(0) + } + + def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = JVMObjectTracker.put(value) + writeString(out, objId) + } + + def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + + def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + writeType(out, "raw") + out.writeInt(value.length) + value.foreach(v => writeBytes(out, v)) + } +} + +private[r] object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala new file mode 100644 index 0000000000000..e99779f299785 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.concurrent.{Semaphore, TimeUnit} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path + +import org.apache.spark.api.r.RBackend +import org.apache.spark.util.RedirectThread + +/** + * Main class used to launch SparkR applications using spark-submit. It executes R as a + * subprocess and then has it connect back to the JVM to access system properties etc. + */ +object RRunner { + def main(args: Array[String]): Unit = { + val rFile = PythonRunner.formatPath(args(0)) + + val otherArgs = args.slice(1, args.length) + + // Time to wait for SparkR backend to initialize in seconds + val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt + val rCommand = "Rscript" + + // Check if the file path exists. + // If not, change directory to current working directory for YARN cluster mode + val rF = new File(rFile) + val rFileNormalized = if (!rF.exists()) { + new Path(rFile).getName + } else { + rFile + } + + // Launch a SparkR backend server for the R process to connect to; this will let it see our + // Java system properties etc. + val sparkRBackend = new RBackend() + @volatile var sparkRBackendPort = 0 + val initialized = new Semaphore(0) + val sparkRBackendThread = new Thread("SparkR backend") { + override def run() { + sparkRBackendPort = sparkRBackend.init() + initialized.release() + sparkRBackend.run() + } + } + + sparkRBackendThread.start() + // Wait for RBackend initialization to finish + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + // Launch R + val returnCode = try { + val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val env = builder.environment() + env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + val sparkHome = System.getenv("SPARK_HOME") + env.put("R_PROFILE_USER", + Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + val process = builder.start() + + new RedirectThread(process.getInputStream, System.out, "redirect R output").start() + + process.waitFor() + } finally { + sparkRBackend.close() + } + System.exit(returnCode) + } else { + System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + System.exit(-1) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 660307d19eab4..60bc243ebf40a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -77,6 +77,7 @@ object SparkSubmit { // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 @@ -284,6 +285,13 @@ object SparkSubmit { } } + // Require all R files to be local + if (args.isR && !isYarnCluster) { + if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { + printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + } + } + // The following modes are not supported or applicable (clusterManager, deployMode) match { case (MESOS, CLUSTER) => @@ -291,6 +299,9 @@ object SparkSubmit { case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") + case (STANDALONE, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + + "applications on standalone clusters.") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => @@ -317,11 +328,32 @@ object SparkSubmit { } } - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython && isYarnCluster) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) + // If we're running a R app, set the main class to our specific R runner + if (args.isR && deployMode == CLIENT) { + if (args.primaryResource == SPARKR_SHELL) { + args.mainClass = "org.apache.spark.api.r.RBackend" + } else { + // If a R file is provided, add it to the child arguments and list of files to deploy. + // Usage: RRunner
[app arguments] + args.mainClass = "org.apache.spark.deploy.RRunner" + args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + + if (isYarnCluster) { + // In yarn-cluster mode for a python app, add primary resource and pyFiles to files + // that can be distributed with the job + if (args.isPython) { + args.files = mergeFileLists(args.files, args.primaryResource) + args.files = mergeFileLists(args.files, args.pyFiles) + } + + // In yarn-cluster mode for a R app, add primary resource to files + // that can be distributed with the job + if (args.isR) { + args.files = mergeFileLists(args.files, args.primaryResource) + } } // Special flag to avoid deprecation warnings at the client @@ -405,8 +437,8 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" - // For python files, the primary resource is already distributed as a regular file - if (!isYarnCluster && !args.isPython) { + // For python and R files, the primary resource is already distributed as a regular file + if (!isYarnCluster && !args.isPython && !args.isR) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) @@ -447,6 +479,10 @@ object SparkSubmit { childArgs += ("--py-files", pyFilesNames) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + } else if (args.isR) { + val mainFile = new Path(args.primaryResource).getName + childArgs += ("--primary-r-file", mainFile) + childArgs += ("--class", "org.apache.spark.deploy.RRunner") } else { if (args.primaryResource != SPARK_INTERNAL) { childArgs += ("--jar", args.primaryResource) @@ -591,15 +627,15 @@ object SparkSubmit { /** * Return whether the given primary resource represents a user jar. */ - private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) + private[deploy] def isUserJar(res: String): Boolean = { + !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res) } /** * Return whether the given primary resource represents a shell. */ - private[deploy] def isShell(primaryResource: String): Boolean = { - primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL + private[deploy] def isShell(res: String): Boolean = { + (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL) } /** @@ -619,12 +655,19 @@ object SparkSubmit { /** * Return whether the given primary resource requires running python. */ - private[deploy] def isPython(primaryResource: String): Boolean = { - primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL + private[deploy] def isPython(res: String): Boolean = { + res != null && res.endsWith(".py") || res == PYSPARK_SHELL + } + + /** + * Return whether the given primary resource requires running R. + */ + private[deploy] def isR(res: String): Boolean = { + res != null && res.endsWith(".R") || res == SPARKR_SHELL } - private[deploy] def isInternal(primaryResource: String): Boolean = { - primaryResource == SPARK_INTERNAL + private[deploy] def isInternal(res: String): Boolean = { + res == SPARK_INTERNAL } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 6eb73c43470a5..03ecf3fd99ec5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var isR: Boolean = false var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() var proxyUser: String = null @@ -158,7 +159,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .getOrElse(sparkProperties.get("spark.executor.instances").orNull) // Try to set main class from JAR if no --class argument is given - if (mainClass == null && !isPython && primaryResource != null) { + if (mainClass == null && !isPython && !isR && primaryResource != null) { val uri = new URI(primaryResource) val uriScheme = uri.getScheme() @@ -211,9 +212,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python file)") + SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") } - if (mainClass == null && !isPython) { + if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") } if (pyFiles != null && !isPython) { @@ -414,6 +415,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S opt } isPython = SparkSubmit.isPython(opt) + isR = SparkSubmit.isR(opt) false } diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/dev/run-tests b/dev/run-tests index 561d7fc9e7b1f..1b6cf78b5da01 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -236,3 +236,18 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS ./python/run-tests + +echo "" +echo "=========================================================================" +echo "Running SparkR tests" +echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_SPARKR_UNIT_TESTS + +if [ $(command -v R) ]; then + ./R/install-dev.sh + ./R/run-tests.sh +else + echo "Ignoring SparkR tests as R was not found in PATH" +fi + diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index 8ab6db6925d6e..154e01255b2ef 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -25,3 +25,4 @@ readonly BLOCK_BUILD=14 readonly BLOCK_MIMA=15 readonly BLOCK_SPARK_UNIT_TESTS=16 readonly BLOCK_PYSPARK_UNIT_TESTS=17 +readonly BLOCK_SPARKR_UNIT_TESTS=18 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f10aa6b59e1af..f6372835a6dbf 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -210,6 +210,8 @@ done failing_test="Spark unit tests" elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then failing_test="PySpark unit tests" + elif [ "$test_result" -eq "$BLOCK_SPARKR_UNIT_TESTS" ]; then + failing_test="SparkR unit tests" else failing_test="some tests" fi diff --git a/docs/README.md b/docs/README.md index 3773ea25c8b67..5852f972a051d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -58,13 +58,19 @@ phase, use the following sytax: We use Sphinx to generate Python API docs, so you will need to install it by running `sudo pip install sphinx`. -## API Docs (Scaladoc and Sphinx) +## knitr, devtools + +SparkR documentation is written using `roxygen2` and we use `knitr`, `devtools` to generate +documentation. To install these packages you can run `install.packages(c("knitr", "devtools"))` from a +R console. + +## API Docs (Scaladoc, Sphinx, roxygen2) You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory. Similarly, you can build just the PySpark docs by running `make html` from the SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. +public in `__init__.py`. The SparkR docs can be built by running SPARK_PROJECT_ROOT/R/create-docs.sh. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a @@ -72,5 +78,5 @@ jekyll plugin to run `build/sbt unidoc` before building the site so if you haven may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs [Sphinx](http://sphinx-doc.org/). -NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 +NOTE: To skip the step of building and copying over the Scala, Python, R API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 2e88b3093652d..b92c75f90b11c 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -84,6 +84,7 @@
  • Scala
  • Java
  • Python
  • +
  • R
  • diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3c626a0b7f54b..0ea3f8eab461b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -78,5 +78,18 @@ puts "cp -r python/docs/_build/html/. docs/api/python" cp_r("python/docs/_build/html/.", "docs/api/python") - cd("..") + # Build SparkR API docs + puts "Moving to R directory and building roxygen docs." + cd("R") + puts `./create-docs.sh` + + puts "Moving back into home dir." + cd("../") + + puts "Making directory api/R" + mkdir_p "docs/api/R" + + puts "cp -r R/pkg/html/. docs/api/R" + cp_r("R/pkg/html/.", "docs/api/R") + end diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 6a75d5c457f02..7079de546e2f5 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -33,7 +33,11 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. Because the driver schedules tasks on the cluster, it should be run close to the worker +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network + addressable from the worker nodes. +4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the cluster remotely, it's better to open an RPC to the driver and have it submit operations from nearby than to run a driver far away from the worker nodes. diff --git a/docs/img/cluster-overview.png b/docs/img/cluster-overview.png index 368274068e754..317554c5f2a5b 100644 Binary files a/docs/img/cluster-overview.png and b/docs/img/cluster-overview.png differ diff --git a/docs/img/cluster-overview.pptx b/docs/img/cluster-overview.pptx index af3c462cd904d..1b90d7ec5a7ae 100644 Binary files a/docs/img/cluster-overview.pptx and b/docs/img/cluster-overview.pptx differ diff --git a/examples/pom.xml b/examples/pom.xml index 7e93f0eec0b91..afd7c6d52f0dd 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -90,6 +90,12 @@ org.apache.spark spark-streaming-zeromq_${scala.binary.version} ${project.version} + + + org.spark-project.protobuf + protobuf-java + + org.apache.hbase @@ -234,6 +240,7 @@ org.apache.commons commons-math3 + provided com.twitter @@ -262,6 +269,22 @@ com.ning compress-lzf + + commons-cli + commons-cli + + + commons-codec + commons-codec + + + commons-lang + commons-lang + + + commons-logging + commons-logging + io.netty netty @@ -270,10 +293,22 @@ jline jline + + net.jpountz.lz4 + lz4 + org.apache.cassandra.deps avro + + org.apache.commons + commons-math3 + + + org.apache.thrift + libthrift + @@ -281,6 +316,17 @@ scopt_${scala.binary.version} 3.2.0 + + + + org.scala-lang + scala-library + provided + + @@ -322,12 +368,6 @@ - - - org.apache.commons.math3 - org.spark-project.commons.math3 - - diff --git a/examples/src/main/r/kmeans.R b/examples/src/main/r/kmeans.R new file mode 100644 index 0000000000000..6e6b5cb93789c --- /dev/null +++ b/examples/src/main/r/kmeans.R @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +# Logistic regression in Spark. +# Note: unlike the example in Scala, a point here is represented as a vector of +# doubles. + +parseVectors <- function(lines) { + lines <- strsplit(as.character(lines) , " ", fixed = TRUE) + list(matrix(as.numeric(unlist(lines)), ncol = length(lines[[1]]))) +} + +dist.fun <- function(P, C) { + apply( + C, + 1, + function(x) { + colSums((t(P) - x)^2) + } + ) +} + +closestPoint <- function(P, C) { + max.col(-dist.fun(P, C)) +} +# Main program + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: kmeans ") + q("no") +} + +sc <- sparkR.init(appName = "RKMeans") +K <- as.integer(args[[2]]) +convergeDist <- as.double(args[[3]]) + +lines <- textFile(sc, args[[1]]) +points <- cache(lapplyPartition(lines, parseVectors)) +# kPoints <- take(points, K) +kPoints <- do.call(rbind, takeSample(points, FALSE, K, 16189L)) +tempDist <- 1.0 + +while (tempDist > convergeDist) { + closest <- lapplyPartition( + lapply(points, + function(p) { + cp <- closestPoint(p, kPoints); + mapply(list, unique(cp), split.data.frame(cbind(1, p), cp), SIMPLIFY=FALSE) + }), + function(x) {do.call(c, x) + }) + + pointStats <- reduceByKey(closest, + function(p1, p2) { + t(colSums(rbind(p1, p2))) + }, + 2L) + + newPoints <- do.call( + rbind, + collect(lapply(pointStats, + function(tup) { + point.sum <- tup[[2]][, -1] + point.count <- tup[[2]][, 1] + point.sum/point.count + }))) + + D <- dist.fun(kPoints, newPoints) + tempDist <- sum(D[cbind(1:3, max.col(-D))]) + kPoints <- newPoints + cat("Finished iteration (delta = ", tempDist, ")\n") +} + +cat("Final centers:\n") +writeLines(unlist(lapply(kPoints, paste, collapse = " "))) diff --git a/examples/src/main/r/linear_solver_mnist.R b/examples/src/main/r/linear_solver_mnist.R new file mode 100644 index 0000000000000..c864a4232d010 --- /dev/null +++ b/examples/src/main/r/linear_solver_mnist.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Instructions: https://github.com/amplab-extras/SparkR-pkg/wiki/SparkR-Example:-Digit-Recognition-on-EC2 + +library(SparkR) +library(Matrix) + +args <- commandArgs(trailing = TRUE) + +# number of random features; default to 1100 +D <- ifelse(length(args) > 0, as.integer(args[[1]]), 1100) +# number of partitions for training dataset +trainParts <- 12 +# dimension of digits +d <- 784 +# number of test examples +NTrain <- 60000 +# number of training examples +NTest <- 10000 +# scale of features +gamma <- 4e-4 + +sc <- sparkR.init(appName = "SparkR-LinearSolver") + +# You can also use HDFS path to speed things up: +# hdfs:///train-mnist-dense-with-labels.data +file <- textFile(sc, "/data/train-mnist-dense-with-labels.data", trainParts) + +W <- gamma * matrix(nrow=D, ncol=d, data=rnorm(D*d)) +b <- 2 * pi * matrix(nrow=D, ncol=1, data=runif(D)) +broadcastW <- broadcast(sc, W) +broadcastB <- broadcast(sc, b) + +includePackage(sc, Matrix) +numericLines <- lapplyPartitionsWithIndex(file, + function(split, part) { + matList <- sapply(part, function(line) { + as.numeric(strsplit(line, ",", fixed=TRUE)[[1]]) + }, simplify=FALSE) + mat <- Matrix(ncol=d+1, data=unlist(matList, F, F), + sparse=T, byrow=T) + mat + }) + +featureLabels <- cache(lapplyPartition( + numericLines, + function(part) { + label <- part[,1] + mat <- part[,-1] + ones <- rep(1, nrow(mat)) + features <- cos( + mat %*% t(value(broadcastW)) + (matrix(ncol=1, data=ones) %*% t(value(broadcastB)))) + onesMat <- Matrix(ones) + featuresPlus <- cBind(features, onesMat) + labels <- matrix(nrow=nrow(mat), ncol=10, data=-1) + for (i in 1:nrow(mat)) { + labels[i, label[i]] <- 1 + } + list(label=labels, features=featuresPlus) + })) + +FTF <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$features + }), flatten=F)) + +FTY <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$label + }), flatten=F)) + +# solve for the coefficient matrix +C <- solve(FTF, FTY) + +test <- Matrix(as.matrix(read.csv("/data/test-mnist-dense-with-labels.data", + header=F), sparse=T)) +testData <- test[,-1] +testLabels <- matrix(ncol=1, test[,1]) + +err <- 0 + +# contstruct the feature maps for all examples from this digit +featuresTest <- cos(testData %*% t(value(broadcastW)) + + (matrix(ncol=1, data=rep(1, NTest)) %*% t(value(broadcastB)))) +featuresTest <- cBind(featuresTest, Matrix(rep(1, NTest))) + +# extract the one vs. all assignment +results <- featuresTest %*% C +labelsGot <- apply(results, 1, which.max) +err <- sum(testLabels != labelsGot) / nrow(testLabels) + +cat("\nFinished running. The error rate is: ", err, ".\n") diff --git a/examples/src/main/r/logistic_regression.R b/examples/src/main/r/logistic_regression.R new file mode 100644 index 0000000000000..2a86aa98160d3 --- /dev/null +++ b/examples/src/main/r/logistic_regression.R @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: logistic_regression ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "LogisticRegressionR") +iterations <- as.integer(args[[2]]) +D <- as.integer(args[[3]]) + +readPartition <- function(part){ + part = strsplit(part, " ", fixed = T) + list(matrix(as.numeric(unlist(part)), ncol = length(part[[1]]))) +} + +# Read data points and convert each partition to a matrix +points <- cache(lapplyPartition(textFile(sc, args[[1]]), readPartition)) + +# Initialize w to a random value +w <- runif(n=D, min = -1, max = 1) +cat("Initial w: ", w, "\n") + +# Compute logistic regression gradient for a matrix of data points +gradient <- function(partition) { + partition = partition[[1]] + Y <- partition[, 1] # point labels (first column of input file) + X <- partition[, -1] # point coordinates + + # For each point (x, y), compute gradient function + dot <- X %*% w + logit <- 1 / (1 + exp(-Y * dot)) + grad <- t(X) %*% ((logit - 1) * Y) + list(grad) +} + +for (i in 1:iterations) { + cat("On iteration ", i, "\n") + w <- w - reduce(lapplyPartition(points, gradient), "+") +} + +cat("Final w: ", w, "\n") diff --git a/examples/src/main/r/pi.R b/examples/src/main/r/pi.R new file mode 100644 index 0000000000000..aa7a833e147a0 --- /dev/null +++ b/examples/src/main/r/pi.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +sc <- sparkR.init(appName = "PiR") + +slices <- ifelse(length(args) > 1, as.integer(args[[2]]), 2) + +n <- 100000 * slices + +piFunc <- function(elem) { + rands <- runif(n = 2, min = -1, max = 1) + val <- ifelse((rands[1]^2 + rands[2]^2) < 1, 1.0, 0.0) + val +} + + +piFuncVec <- function(elems) { + message(length(elems)) + rands1 <- runif(n = length(elems), min = -1, max = 1) + rands2 <- runif(n = length(elems), min = -1, max = 1) + val <- ifelse((rands1^2 + rands2^2) < 1, 1.0, 0.0) + sum(val) +} + +rdd <- parallelize(sc, 1:n, slices) +count <- reduce(lapplyPartition(rdd, piFuncVec), sum) +cat("Pi is roughly", 4.0 * count / n, "\n") +cat("Num elements in RDD ", count(rdd), "\n") diff --git a/examples/src/main/r/wordcount.R b/examples/src/main/r/wordcount.R new file mode 100644 index 0000000000000..b734cb0ecf55b --- /dev/null +++ b/examples/src/main/r/wordcount.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: wordcount ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "RwordCount") +lines <- textFile(sc, args[[1]]) + +words <- flatMap(lines, + function(line) { + strsplit(line, " ")[[1]] + }) +wordCount <- lapply(words, function(word) { list(word, 1L) }) + +counts <- reduceByKey(wordCount, "+", 2L) +output <- collect(counts) + +for (wordcount in output) { + cat(wordcount[[1]], ": ", wordcount[[2]], "\n") +} diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties index 2a58e99817224..42df8792f147f 100644 --- a/external/flume-sink/src/test/resources/log4j.properties +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/mqtt/src/test/resources/log4j.properties +++ b/external/mqtt/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties index 64bfc5745088f..9a3569789d2e0 100644 --- a/external/twitter/src/test/resources/log4j.properties +++ b/external/twitter/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/zeromq/src/test/resources/log4j.properties +++ b/external/zeromq/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/extras/java8-tests/src/test/resources/log4j.properties +++ b/extras/java8-tests/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties index 97348fb5b6123..6cdc9286c5d76 100644 --- a/extras/kinesis-asl/src/main/resources/log4j.properties +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -31,7 +31,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties index 853ef0ed2986f..edbecdae92096 100644 --- a/extras/kinesis-asl/src/test/resources/log4j.properties +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 9b04732afee14..f4ebc25bdd32b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -274,14 +274,14 @@ static String quoteForBatchScript(String arg) { } /** - * Quotes a string so that it can be used in a command string and be parsed back into a single - * argument by python's "shlex.split()" function. - * + * Quotes a string so that it can be used in a command string. * Basically, just add simple escapes. E.g.: * original single argument : ab "cd" ef * after: "ab \"cd\" ef" + * + * This can be parsed back into a single argument by python's "shlex.split()" function. */ - static String quoteForPython(String s) { + static String quoteForCommandString(String s) { StringBuilder quoted = new StringBuilder().append('"'); for (int i = 0; i < s.length(); i++) { int cp = s.codePointAt(i); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 91dcf70f105db..a73c9c87e3126 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -17,14 +17,9 @@ package org.apache.spark.launcher; +import java.io.File; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Properties; +import java.util.*; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -53,6 +48,20 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ static final String PYSPARK_SHELL_RESOURCE = "pyspark-shell"; + /** + * Name of the app resource used to identify the SparkR shell. The command line parser expects + * the resource name to be the very first argument to spark-submit in this case. + * + * NOTE: this cannot be "sparkr-shell" since that identifies the SparkR shell to SparkSubmit + * (see sparkR.R), and can cause this code to enter into an infinite loop. + */ + static final String SPARKR_SHELL = "sparkr-shell-main"; + + /** + * This is the actual resource name that identifies the SparkR shell to SparkSubmit. + */ + static final String SPARKR_SHELL_RESOURCE = "sparkr-shell"; + /** * This map must match the class names for available special classes, since this modifies the way * command line parsing works. This maps the class name to the resource to use when calling @@ -87,6 +96,10 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = true; appResource = PYSPARK_SHELL_RESOURCE; submitArgs = args.subList(1, args.size()); + } else if (args.size() > 0 && args.get(0).equals(SPARKR_SHELL)) { + this.allowsMixedArguments = true; + appResource = SPARKR_SHELL_RESOURCE; + submitArgs = args.subList(1, args.size()); } else { this.allowsMixedArguments = false; } @@ -98,6 +111,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { public List buildCommand(Map env) throws IOException { if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { return buildPySparkShellCommand(env); + } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); } @@ -213,36 +228,62 @@ private List buildPySparkShellCommand(Map env) throws IO return buildCommand(env); } - // When launching the pyspark shell, the spark-submit arguments should be stored in the - // PYSPARK_SUBMIT_ARGS env variable. The executable is the PYSPARK_DRIVER_PYTHON env variable - // set by the pyspark script, followed by PYSPARK_DRIVER_PYTHON_OPTS. checkArgument(appArgs.isEmpty(), "pyspark does not support any application options."); + // When launching the pyspark shell, the spark-submit arguments should be stored in the + // PYSPARK_SUBMIT_ARGS env variable. + constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS"); + + // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script, + // followed by PYSPARK_DRIVER_PYTHON_OPTS. + List pyargs = new ArrayList(); + pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); + String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); + if (!isEmpty(pyOpts)) { + pyargs.addAll(parseOptionString(pyOpts)); + } + + return pyargs; + } + + private List buildSparkRCommand(Map env) throws IOException { + if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { + appResource = appArgs.get(0); + appArgs.remove(0); + return buildCommand(env); + } + // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS + // env variable. + constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS"); + + // Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up. + String sparkHome = System.getenv("SPARK_HOME"); + env.put("R_PROFILE_USER", + join(File.separator, sparkHome, "R", "lib", "SparkR", "profile", "shell.R")); + + List args = new ArrayList(); + args.add(firstNonEmpty(System.getenv("SPARKR_DRIVER_R"), "R")); + return args; + } + + private void constructEnvVarArgs( + Map env, + String submitArgsEnvVariable) throws IOException { Properties props = loadPropertiesFile(); mergeEnvPathList(env, getLibPathEnvName(), firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); - // Store spark-submit arguments in an environment variable, since there's no way to pass - // them to shell.py on the comand line. StringBuilder submitArgs = new StringBuilder(); for (String arg : buildSparkSubmitArgs()) { if (submitArgs.length() > 0) { submitArgs.append(" "); } - submitArgs.append(quoteForPython(arg)); + submitArgs.append(quoteForCommandString(arg)); } - env.put("PYSPARK_SUBMIT_ARGS", submitArgs.toString()); - - List pyargs = new ArrayList(); - pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); - String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); - if (!isEmpty(pyOpts)) { - pyargs.addAll(parseOptionString(pyOpts)); - } - - return pyargs; + env.put(submitArgsEnvVariable, submitArgs.toString()); } + private boolean isClientMode(Properties userProps) { String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); // Default master is "local[*]", so assume client mode in that case. diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index dba0203867372..1ae42eed8a3af 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -79,9 +79,9 @@ public void testWindowsBatchQuoting() { @Test public void testPythonArgQuoting() { - assertEquals("\"abc\"", quoteForPython("abc")); - assertEquals("\"a b c\"", quoteForPython("a b c")); - assertEquals("\"a \\\"b\\\" c\"", quoteForPython("a \"b\" c")); + assertEquals("\"abc\"", quoteForCommandString("abc")); + assertEquals("\"a b c\"", quoteForCommandString("a b c")); + assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c")); } private void testOpt(String opts, List expected) { diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index 00c20ad69cd4d..67a6a98217118 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -27,5 +27,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala new file mode 100644 index 0000000000000..ee933f4cfcafd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of FPGrowthModel to provide helper method for Python + */ +private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any]) + extends FPGrowthModel(model.freqItemsets) { + + def getFreqItemsets: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq))) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 6c386cacfb7ca..1faa3def0e042 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -34,6 +34,7 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature._ +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -358,9 +359,7 @@ private[python] class PythonMLLibAPI extends Serializable { val model = new GaussianMixtureModel(weight, gaussians) model.predictSoft(data) } - - - + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -420,6 +419,24 @@ private[python] class PythonMLLibAPI extends Serializable { new MatrixFactorizationModelWrapper(model) } + /** + * Java stub for Python mllib FPGrowth.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainFPGrowthModel( + data: JavaRDD[java.lang.Iterable[Any]], + minSupport: Double, + numPartitions: Int): FPGrowthModel[Any] = { + val fpg = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartitions) + + val model = fpg.run(data.rdd.map(_.asScala.toArray)) + new FPGrowthModelWrapper(model) + } + /** * Java stub for Normalizer.transform() */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d1a174063caba..3fa5e068d16d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -87,6 +87,9 @@ sealed trait Matrix extends Serializable { /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() + /** A human readable representation of the matrix with maximum lines and width */ + def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) + /** Map the values of this matrix using a function. Generates a new matrix. Performs the * function on only the backing array. For example, an operation such as addition or * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/mllib/src/test/resources/log4j.properties +++ b/mllib/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 0d2cec58e2c03..86119ec38101e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -439,4 +439,20 @@ class MatricesSuite extends FunSuite { assert(mUDT.typeName == "matrix") assert(mUDT.simpleString == "matrix") } + + test("toString") { + val empty = Matrices.ones(0, 0) + empty.toString(0, 0) + + val mat = Matrices.rand(5, 10, new Random()) + mat.toString(-1, -5) + mat.toString(0, 0) + mat.toString(Int.MinValue, Int.MinValue) + mat.toString(Int.MaxValue, Int.MaxValue) + var lines = mat.toString(6, 50).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 50)) + + lines = mat.toString(5, 100).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 100)) + } } diff --git a/pom.xml b/pom.xml index 42bd926a2fcb8..70e297c4f082a 100644 --- a/pom.xml +++ b/pom.xml @@ -1749,5 +1749,8 @@ parquet-provided + + sparkr + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index c2d828f982fe0..1564babefa62f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -64,6 +64,10 @@ object MimaExcludes { // SPARK-6492 Fix deadlock in SparkContext.stop() ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") + )++ Seq( + // SPARK-6693 add tostring with max lines and width for matrix + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.toString") ) case v if v.startsWith("1.3") => diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 15101470afc07..26ece4c2c389a 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -31,6 +31,13 @@ pyspark.mllib.feature module :undoc-members: :show-inheritance: +pyspark.mllib.fpm module +------------------------ + +.. automodule:: pyspark.mllib.fpm + :members: + :undoc-members: + pyspark.mllib.linalg module --------------------------- diff --git a/python/pyspark/join.py b/python/pyspark/join.py index efc1ef9396412..c3491defb2b29 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -48,7 +48,7 @@ def dispatch(seq): vbuf.append(v) elif n == 2: wbuf.append(v) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -62,7 +62,7 @@ def dispatch(seq): wbuf.append(v) if not vbuf: vbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -76,7 +76,7 @@ def dispatch(seq): wbuf.append(v) if not wbuf: wbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -104,8 +104,9 @@ def make_mapper(i): rdd_len = len(vrdds) def dispatch(seq): - bufs = [[] for i in range(rdd_len)] - for (n, v) in seq: + bufs = [[] for _ in range(rdd_len)] + for n, v in seq: bufs[n].append(v) - return tuple(map(ResultIterable, bufs)) + return tuple(ResultIterable(vs) for vs in bufs) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 6449800d9c120..f2ef573fe9f6f 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -25,7 +25,7 @@ if numpy.version.version < '1.4': raise Exception("MLlib requires NumPy 1.4+") -__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random', +__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] import sys diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py new file mode 100644 index 0000000000000..3aa6d79d7093c --- /dev/null +++ b/python/pyspark/mllib/fpm.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc + +__all__ = ['FPGrowth', 'FPGrowthModel'] + + +@inherit_doc +class FPGrowthModel(JavaModelWrapper): + + """ + .. note:: Experimental + + A FP-Growth model for mining frequent itemsets + using the Parallel FP-Growth algorithm. + + >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + >>> rdd = sc.parallelize(data, 2) + >>> model = FPGrowth.train(rdd, 0.6, 2) + >>> sorted(model.freqItemsets().collect()) + [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)] + """ + + def freqItemsets(self): + """ + Get the frequent itemsets of this model + """ + return self.call("getFreqItemsets") + + +class FPGrowth(object): + """ + .. note:: Experimental + + A Parallel FP-growth algorithm to mine frequent itemsets. + """ + + @classmethod + def train(cls, data, minSupport=0.3, numPartitions=-1): + """ + Computes an FP-Growth model that contains frequent itemsets. + :param data: The input data set, each element + contains a transaction. + :param minSupport: The minimal support level + (default: `0.3`). + :param numPartitions: The number of partitions used by parallel + FP-growth (default: same as input data). + """ + model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) + return FPGrowthModel(model) + + +def _test(): + import doctest + import pyspark.mllib.fpm + globs = pyspark.mllib.fpm.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2d05611321ed6..1b18789040360 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -41,7 +41,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ - get_used_memory, ExternalSorter + get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync from py4j.java_collections import ListConverter, MapConverter @@ -573,8 +573,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') - memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): @@ -1699,10 +1699,8 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() - == 'true') - memory = _parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): @@ -1755,21 +1753,28 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + def _can_spill(self): + return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" + + def _memory_limit(self): + return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + # TODO: support variant with custom partitioner def groupByKey(self, numPartitions=None): """ Group the values for each key in the RDD into a single sequence. - Hash-partitions the resulting RDD with into numPartitions partitions. + Hash-partitions the resulting RDD with numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will provide much better performance. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) + >>> sorted(x.groupByKey().mapValues(len).collect()) + [('a', 2), ('b', 1)] + >>> sorted(x.groupByKey().mapValues(list).collect()) [('a', [1, 1]), ('b', [1])] """ - def createCombiner(x): return [x] @@ -1781,8 +1786,27 @@ def mergeCombiners(a, b): a.extend(b) return a - return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numPartitions).mapValues(lambda x: ResultIterable(x)) + spill = self._can_spill() + memory = self._memory_limit() + serializer = self._jrdd_deserializer + agg = Aggregator(createCombiner, mergeValue, mergeCombiners) + + def combine(iterator): + merger = ExternalMerger(agg, memory * 0.9, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeValues(iterator) + return merger.iteritems() + + locally_combined = self.mapPartitions(combine, preservesPartitioning=True) + shuffled = locally_combined.partitionBy(numPartitions) + + def groupByKey(it): + merger = ExternalGroupBy(agg, memory, serializer)\ + if spill else InMemoryMerger(agg) + merger.mergeCombiners(it) + return merger.iteritems() + + return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable) def flatMapValues(self, f): """ diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index ef04c82866e6c..1ab5ce14c3531 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -15,15 +15,16 @@ # limitations under the License. # -__all__ = ["ResultIterable"] - import collections +__all__ = ["ResultIterable"] + class ResultIterable(collections.Iterable): """ - A special result iterable. This is used because the standard iterator can not be pickled + A special result iterable. This is used because the standard + iterator can not be pickled """ def __init__(self, data): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 0ffb41d02f6f6..4afa82f4b2973 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -220,6 +220,29 @@ def __repr__(self): return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) +class FlattenedValuesSerializer(BatchedSerializer): + + """ + Serializes a stream of list of pairs, split the list of values + which contain more than a certain number of objects to make them + have similar sizes. + """ + def __init__(self, serializer, batchSize=10): + BatchedSerializer.__init__(self, serializer, batchSize) + + def _batched(self, iterator): + n = self.batchSize + for key, values in iterator: + for i in xrange(0, len(values), n): + yield key, values[i:i + n] + + def load_stream(self, stream): + return self.serializer.load_stream(stream) + + def __repr__(self): + return "FlattenedValuesSerializer(%d)" % self.batchSize + + class AutoBatchedSerializer(BatchedSerializer): """ Choose the size of batch automatically based on the size of object @@ -251,7 +274,7 @@ def __eq__(self, other): return (isinstance(other, AutoBatchedSerializer) and other.serializer == self.serializer and other.bestSize == self.bestSize) - def __str__(self): + def __repr__(self): return "AutoBatchedSerializer(%s)" % str(self.serializer) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 10a7ccd502000..8a6fc627eb383 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -16,28 +16,35 @@ # import os -import sys import platform import shutil import warnings import gc import itertools +import operator import random import pyspark.heapq3 as heapq -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ + CompressedSerializer, AutoBatchedSerializer + try: import psutil + process = None + def get_used_memory(): """ Return the used memory in MB """ - process = psutil.Process(os.getpid()) + global process + if process is None or process._pid != os.getpid(): + process = psutil.Process(os.getpid()) if hasattr(process, "memory_info"): info = process.memory_info() else: info = process.get_memory_info() return info.rss >> 20 + except ImportError: def get_used_memory(): @@ -46,6 +53,7 @@ def get_used_memory(): for line in open('/proc/self/status'): if line.startswith('VmRSS:'): return int(line.split()[1]) >> 10 + else: warnings.warn("Please install psutil to have better " "support with spilling") @@ -54,6 +62,7 @@ def get_used_memory(): rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss return rss >> 20 # TODO: support windows + return 0 @@ -148,10 +157,16 @@ def mergeCombiners(self, iterator): d[k] = comb(d[k], v) if k in d else v def iteritems(self): - """ Return the merged items ad iterator """ + """ Return the merged items as iterator """ return self.data.iteritems() +def _compressed_serializer(self, serializer=None): + # always use PickleSerializer to simplify implementation + ser = PickleSerializer() + return AutoBatchedSerializer(CompressedSerializer(ser)) + + class ExternalMerger(Merger): """ @@ -173,7 +188,7 @@ class ExternalMerger(Merger): dict. Repeat this again until combine all the items. - Before return any items, it will load each partition and - combine them seperately. Yield them before loading next + combine them separately. Yield them before loading next partition. - During loading a partition, if the memory goes over limit, @@ -182,7 +197,7 @@ class ExternalMerger(Merger): `data` and `pdata` are used to hold the merged items in memory. At first, all the data are merged into `data`. Once the used - memory goes over limit, the items in `data` are dumped indo + memory goes over limit, the items in `data` are dumped into disks, `data` will be cleared, all rest of items will be merged into `pdata` and then dumped into disks. Before returning, all the items in `pdata` will be dumped into disks. @@ -193,16 +208,16 @@ class ExternalMerger(Merger): >>> agg = SimpleAggregator(lambda x, y: x + y) >>> merger = ExternalMerger(agg, 10) >>> N = 10000 - >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeValues(zip(xrange(N), xrange(N))) >>> assert merger.spills > 0 >>> sum(v for k,v in merger.iteritems()) - 499950000 + 49995000 >>> merger = ExternalMerger(agg, 10) - >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeCombiners(zip(xrange(N), xrange(N))) >>> assert merger.spills > 0 >>> sum(v for k,v in merger.iteritems()) - 499950000 + 49995000 """ # the max total partitions created recursively @@ -212,8 +227,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit - # default serializer is only used for tests - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -221,7 +235,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, self.batch = batch # scale is used to scale down the hash of key for recursive hash map self.scale = scale - # unpartitioned merged data + # un-partitioned merged data self.data = {} # partitioned merged data, list of dicts self.pdata = [] @@ -244,72 +258,63 @@ def _next_limit(self): def mergeValues(self, iterator): """ Combine the items by creator and combiner """ - iterator = iter(iterator) # speedup attribute lookup creator, comb = self.agg.createCombiner, self.agg.mergeValue - d, c, batch = self.data, 0, self.batch + c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch + limit = self.memory_limit for k, v in iterator: + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else creator(v) c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeValues(iterator, self._next_limit()) - break + if c >= batch: + if get_used_memory() >= limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if get_used_memory() >= limit: + self._spill() def _partition(self, key): """ Return the partition for key """ return hash((key, self._seed)) % self.partitions - def _partitioned_mergeValues(self, iterator, limit=0): - """ Partition the items by key, then combine them """ - # speedup attribute lookup - creator, comb = self.agg.createCombiner, self.agg.mergeValue - c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch - - for k, v in iterator: - d = pdata[hfun(k)] - d[k] = comb(d[k], v) if k in d else creator(v) - if not limit: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + def _object_size(self, obj): + """ How much of memory for this obj, assume that all the objects + consume similar bytes of memory + """ + return 1 - def mergeCombiners(self, iterator, check=True): + def mergeCombiners(self, iterator, limit=None): """ Merge (K,V) pair by mergeCombiner """ - iterator = iter(iterator) + if limit is None: + limit = self.memory_limit # speedup attribute lookup - d, comb, batch = self.data, self.agg.mergeCombiners, self.batch - c = 0 - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - if not check: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeCombiners(iterator, self._next_limit()) - break - - def _partitioned_mergeCombiners(self, iterator, limit=0): - """ Partition the items by key, then merge them """ - comb, pdata = self.agg.mergeCombiners, self.pdata - c, hfun = 0, self._partition + comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size + c, data, pdata, batch = 0, self.data, self.pdata, self.batch for k, v in iterator: - d = pdata[hfun(k)] + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else v if not limit: continue - c += 1 - if c % self.batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + c += objsize(v) + if c > batch: + if get_used_memory() > limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if limit and get_used_memory() >= limit: + self._spill() def _spill(self): """ @@ -335,7 +340,7 @@ def _spill(self): for k, v in self.data.iteritems(): h = self._partition(k) - # put one item in batch, make it compatitable with load_stream + # put one item in batch, make it compatible with load_stream # it will increase the memory if dump them in batch self.serializer.dump_stream([(k, v)], streams[h]) @@ -344,7 +349,7 @@ def _spill(self): s.close() self.data.clear() - self.pdata = [{} for i in range(self.partitions)] + self.pdata.extend([{} for i in range(self.partitions)]) else: for i in range(self.partitions): @@ -370,29 +375,12 @@ def _external_items(self): assert not self.data if any(self.pdata): self._spill() - hard_limit = self._next_limit() + # disable partitioning and spilling when merge combiners from disk + self.pdata = [] try: for i in range(self.partitions): - self.data = {} - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), - False) - - # limit the total partitions - if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS - and j < self.spills - 1 - and get_used_memory() > hard_limit): - self.data.clear() # will read from disk again - gc.collect() # release the memory as much as possible - for v in self._recursive_merged_items(i): - yield v - return - - for v in self.data.iteritems(): + for v in self._merged_items(i): yield v self.data.clear() @@ -400,53 +388,56 @@ def _external_items(self): for j in range(self.spills): path = self._get_spill_dir(j) os.remove(os.path.join(path, str(i))) - finally: self._cleanup() - def _cleanup(self): - """ Clean up all the files in disks """ - for d in self.localdirs: - shutil.rmtree(d, True) + def _merged_items(self, index): + self.data = {} + limit = self._next_limit() + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + + # limit the total partitions + if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS + and j < self.spills - 1 + and get_used_memory() > limit): + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible + return self._recursive_merged_items(index) - def _recursive_merged_items(self, start): + return self.data.iteritems() + + def _recursive_merged_items(self, index): """ merge the partitioned items and return the as iterator If one partition can not be fit in memory, then them will be partitioned and merged recursively. """ - # make sure all the data are dumps into disks. - assert not self.data - if any(self.pdata): - self._spill() - assert self.spills > 0 - - for i in range(start, self.partitions): - subdirs = [os.path.join(d, "parts", str(i)) - for d in self.localdirs] - m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions, self.partitions) - m.pdata = [{} for _ in range(self.partitions)] - limit = self._next_limit() - - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - m._partitioned_mergeCombiners( - self.serializer.load_stream(open(p))) - - if get_used_memory() > limit: - m._spill() - limit = self._next_limit() + subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs] + m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs, + self.scale * self.partitions, self.partitions, self.batch) + m.pdata = [{} for _ in range(self.partitions)] + limit = self._next_limit() + + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + m.mergeCombiners(self.serializer.load_stream(open(p)), 0) + + if get_used_memory() > limit: + m._spill() + limit = self._next_limit() - for v in m._external_items(): - yield v + return m._external_items() - # remove the merged partition - for j in range(self.spills): - path = self._get_spill_dir(j) - os.remove(os.path.join(path, str(i))) + def _cleanup(self): + """ Clean up all the files in disks """ + for d in self.localdirs: + shutil.rmtree(d, True) class ExternalSorter(object): @@ -457,6 +448,7 @@ class ExternalSorter(object): The spilling will only happen when the used memory goes above the limit. + >>> sorter = ExternalSorter(1) # 1M >>> import random >>> l = range(1024) @@ -469,7 +461,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -515,6 +507,7 @@ def sorted(self, iterator, key=None, reverse=False): limit = self._next_limit() MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 DiskBytesSpilled += os.path.getsize(path) + os.unlink(path) # data will be deleted after close elif not chunks: batch = min(batch * 2, 10000) @@ -529,6 +522,310 @@ def sorted(self, iterator, key=None, reverse=False): return heapq.merge(chunks, key=key, reverse=reverse) +class ExternalList(object): + """ + ExternalList can have many items which cannot be hold in memory in + the same time. + + >>> l = ExternalList(range(100)) + >>> len(l) + 100 + >>> l.append(10) + >>> len(l) + 101 + >>> for i in range(20240): + ... l.append(i) + >>> len(l) + 20341 + >>> import pickle + >>> l2 = pickle.loads(pickle.dumps(l)) + >>> len(l2) + 20341 + >>> list(l2)[100] + 10 + """ + LIMIT = 10240 + + def __init__(self, values): + self.values = values + self.count = len(values) + self._file = None + self._ser = None + + def __getstate__(self): + if self._file is not None: + self._file.flush() + f = os.fdopen(os.dup(self._file.fileno())) + f.seek(0) + serialized = f.read() + else: + serialized = '' + return self.values, self.count, serialized + + def __setstate__(self, item): + self.values, self.count, serialized = item + if serialized: + self._open_file() + self._file.write(serialized) + else: + self._file = None + self._ser = None + + def __iter__(self): + if self._file is not None: + self._file.flush() + # read all items from disks first + with os.fdopen(os.dup(self._file.fileno()), 'r') as f: + f.seek(0) + for v in self._ser.load_stream(f): + yield v + + for v in self.values: + yield v + + def __len__(self): + return self.count + + def append(self, value): + self.values.append(value) + self.count += 1 + # dump them into disk if the key is huge + if len(self.values) >= self.LIMIT: + self._spill() + + def _open_file(self): + dirs = _get_local_dirs("objects") + d = dirs[id(self) % len(dirs)] + if not os.path.exists(d): + os.makedirs(d) + p = os.path.join(d, str(id)) + self._file = open(p, "w+", 65536) + self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) + os.unlink(p) + + def _spill(self): + """ dump the values into disk """ + global MemoryBytesSpilled, DiskBytesSpilled + if self._file is None: + self._open_file() + + used_memory = get_used_memory() + pos = self._file.tell() + self._ser.dump_stream(self.values, self._file) + self.values = [] + gc.collect() + DiskBytesSpilled += self._file.tell() - pos + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + +class ExternalListOfList(ExternalList): + """ + An external list for list. + + >>> l = ExternalListOfList([[i, i] for i in range(100)]) + >>> len(l) + 200 + >>> l.append(range(10)) + >>> len(l) + 210 + >>> len(list(l)) + 210 + """ + + def __init__(self, values): + ExternalList.__init__(self, values) + self.count = sum(len(i) for i in values) + + def append(self, value): + ExternalList.append(self, value) + # already counted 1 in ExternalList.append + self.count += len(value) - 1 + + def __iter__(self): + for values in ExternalList.__iter__(self): + for v in values: + yield v + + +class GroupByKey(object): + """ + Group a sorted iterator as [(k1, it1), (k2, it2), ...] + + >>> k = [i/3 for i in range(6)] + >>> v = [[i] for i in range(6)] + >>> g = GroupByKey(iter(zip(k, v))) + >>> [(k, list(it)) for k, it in g] + [(0, [0, 1, 2]), (1, [3, 4, 5])] + """ + + def __init__(self, iterator): + self.iterator = iter(iterator) + self.next_item = None + + def __iter__(self): + return self + + def next(self): + key, value = self.next_item if self.next_item else next(self.iterator) + values = ExternalListOfList([value]) + try: + while True: + k, v = next(self.iterator) + if k != key: + self.next_item = (k, v) + break + values.append(v) + except StopIteration: + self.next_item = None + return key, values + + +class ExternalGroupBy(ExternalMerger): + + """ + Group by the items by key. If any partition of them can not been + hold in memory, it will do sort based group by. + + This class works as follows: + + - It repeatedly group the items by key and save them in one dict in + memory. + + - When the used memory goes above memory limit, it will split + the combined data into partitions by hash code, dump them + into disk, one file per partition. If the number of keys + in one partitions is smaller than 1000, it will sort them + by key before dumping into disk. + + - Then it goes through the rest of the iterator, group items + by key into different dict by hash. Until the used memory goes over + memory limit, it dump all the dicts into disks, one file per + dict. Repeat this again until combine all the items. It + also will try to sort the items by key in each partition + before dumping into disks. + + - It will yield the grouped items partitions by partitions. + If the data in one partitions can be hold in memory, then it + will load and combine them in memory and yield. + + - If the dataset in one partition cannot be hold in memory, + it will sort them first. If all the files are already sorted, + it merge them by heap.merge(), so it will do external sort + for all the files. + + - After sorting, `GroupByKey` class will put all the continuous + items with the same key as a group, yield the values as + an iterator. + """ + SORT_KEY_LIMIT = 1000 + + def flattened_serializer(self): + assert isinstance(self.serializer, BatchedSerializer) + ser = self.serializer + return FlattenedValuesSerializer(ser, 20) + + def _object_size(self, obj): + return len(obj) + + def _spill(self): + """ + dump already partitioned data into disks. + """ + global MemoryBytesSpilled, DiskBytesSpilled + path = self._get_spill_dir(self.spills) + if not os.path.exists(path): + os.makedirs(path) + + used_memory = get_used_memory() + if not self.pdata: + # The data has not been partitioned, it will iterator the + # data once, write them into different files, has no + # additional memory. It only called when the memory goes + # above limit at the first time. + + # open all the files for writing + streams = [open(os.path.join(path, str(i)), 'w') + for i in range(self.partitions)] + + # If the number of keys is small, then the overhead of sort is small + # sort them before dumping into disks + self._sorted = len(self.data) < self.SORT_KEY_LIMIT + if self._sorted: + self.serializer = self.flattened_serializer() + for k in sorted(self.data.keys()): + h = self._partition(k) + self.serializer.dump_stream([(k, self.data[k])], streams[h]) + else: + for k, v in self.data.iteritems(): + h = self._partition(k) + self.serializer.dump_stream([(k, v)], streams[h]) + + for s in streams: + DiskBytesSpilled += s.tell() + s.close() + + self.data.clear() + # self.pdata is cached in `mergeValues` and `mergeCombiners` + self.pdata.extend([{} for i in range(self.partitions)]) + + else: + for i in range(self.partitions): + p = os.path.join(path, str(i)) + with open(p, "w") as f: + # dump items in batch + if self._sorted: + # sort by key only (stable) + sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)) + self.serializer.dump_stream(sorted_items, f) + else: + self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.pdata[i].clear() + DiskBytesSpilled += os.path.getsize(p) + + self.spills += 1 + gc.collect() # release the memory as much as possible + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + def _merged_items(self, index): + size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) + for j in range(self.spills)) + # if the memory can not hold all the partition, + # then use sort based merge. Because of compression, + # the data on disks will be much smaller than needed memory + if (size >> 20) >= self.memory_limit / 10: + return self._merge_sorted_items(index) + + self.data = {} + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + return self.data.iteritems() + + def _merge_sorted_items(self, index): + """ load a partition from disk, then sort and group by key """ + def load_partition(j): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + return self.serializer.load_stream(open(p, 'r', 65536)) + + disk_items = [load_partition(j) for j in range(self.spills)] + + if self._sorted: + # all the partitions are already sorted + sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0)) + + else: + # Flatten the combined values, so it will not consume huge + # memory during merging sort. + ser = self.flattened_serializer() + sorter = ExternalSorter(self.memory_limit, ser) + sorted_items = sorter.sorted(itertools.chain(*disk_items), + key=operator.itemgetter(0)) + return ((k, vs) for k, vs in GroupByKey(sorted_items)) + + if __name__ == "__main__": import doctest doctest.testmod() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index dd8d3b1c53733..0bd5d20f7877f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,6 +31,7 @@ import time import zipfile import random +import itertools import threading import hashlib @@ -76,7 +77,7 @@ class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 14 + self.N = 1 << 12 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) self.agg = Aggregator(lambda x: [x], @@ -108,7 +109,7 @@ def test_small_dataset(self): sum(xrange(self.N))) def test_medium_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 30) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), @@ -124,10 +125,36 @@ def test_huge_dataset(self): m = ExternalMerger(self.agg, 10, partitions=3) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), + self.assertEqual(sum(len(v) for k, v in m.iteritems()), self.N * 10) m._cleanup() + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -702,6 +729,21 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "5m") + N = 200001 + kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda (k, vs): k == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N/3)], filtered.mapValues(len).collect()) + self.assertEqual([(N/3, N/3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N/3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalList)) + def test_sort_on_empty_rdd(self): self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) @@ -752,9 +794,9 @@ def test_narrow_dependency_in_join(self): self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) - self.sc.setJobGroup("test1", "test", True) tracker = self.sc.statusTracker() + self.sc.setJobGroup("test1", "test", True) d = sorted(parted.join(parted).collect()) self.assertEqual(10, len(d)) self.assertEqual((0, (0, 0)), d[0]) diff --git a/python/run-tests b/python/run-tests index b7630c356cfae..f569a56fb7a9a 100755 --- a/python/run-tests +++ b/python/run-tests @@ -77,6 +77,7 @@ function run_mllib_tests() { run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/evaluation.py" run_test "pyspark/mllib/feature.py" + run_test "pyspark/mllib/fpm.py" run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index e7e4a4113174a..e2ee9c963a4da 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/sql/catalyst/src/test/resources/log4j.properties +++ b/sql/catalyst/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index a5e6b638d2150..53ad67372e024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.NumericType @Experimental class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { - private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val namedGroupingExprs = groupingExprs.map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala new file mode 100644 index 0000000000000..d1ea7cc3e9162 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.r.SerDe +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} + +private[r] object SQLUtils { + def createSQLContext(jsc: JavaSparkContext): SQLContext = { + new SQLContext(jsc) + } + + def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { + new JavaSparkContext(sqlCtx.sparkContext) + } + + def toSeq[T](arr: Array[T]): Seq[T] = { + arr.toSeq + } + + def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val num = schema.fields.size + val rowRDD = rdd.map(bytesToRow) + sqlContext.createDataFrame(rowRDD, schema) + } + + // A helper to include grouping columns in Agg() + def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = { + val aggExprs = exprs.map { col => + col.expr match { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.simpleString)() + } + } + gd.toDF(aggExprs) + } + + def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { + df.map(r => rowToRBytes(r)) + } + + private[this] def bytesToRow(bytes: Array[Byte]): Row = { + val bis = new ByteArrayInputStream(bytes) + val dis = new DataInputStream(bis) + val num = SerDe.readInt(dis) + Row.fromSeq((0 until num).map { i => + SerDe.readObject(dis) + }.toSeq) + } + + private[this] def rowToRBytes(row: Row): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, row.length) + (0 until row.length).map { idx => + val obj: Object = row(idx).asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def dfToCols(df: DataFrame): Array[Array[Byte]] = { + // localDF is Array[Row] + val localDF = df.collect() + val numCols = df.columns.length + // dfCols is Array[Array[Any]] + val dfCols = convertRowsToColumns(localDF, numCols) + + dfCols.map { col => + colToRBytes(col) + } + } + + def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { + (0 until numCols).map { colIdx => + localDF.map { row => + row(colIdx) + } + }.toArray + } + + def colToRBytes(col: Array[Any]): Array[Byte] = { + val numRows = col.length + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, numRows) + + col.map { item => + val obj: Object = item.asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def saveMode(mode: String): SaveMode = { + mode match { + case "append" => SaveMode.Append + case "overwrite" => SaveMode.Overwrite + case "error" => SaveMode.ErrorIfExists + case "ignore" => SaveMode.Ignore + } + } +} diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 24a1e02795218..32bc4e5663062 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -469,6 +469,9 @@ private[spark] class ApplicationMaster( System.setProperty("spark.submit.pyFiles", PythonRunner.formatPaths(args.pyFiles).mkString(",")) } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + // TODO(davies): add R dependencies here + } val mainMethod = userClassLoader.loadClass(args.userClass) .getMethod("main", classOf[Array[String]]) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index e1a992af3aae7..ae6dc1094d724 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -25,6 +25,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null var userClass: String = null var primaryPyFile: String = null + var primaryRFile: String = null var pyFiles: String = null var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 @@ -54,6 +55,10 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--py-files") :: value :: tail => pyFiles = value args = tail @@ -79,6 +84,11 @@ class ApplicationMasterArguments(val args: Array[String]) { } } + if (primaryPyFile != null && primaryRFile != null) { + System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + System.exit(-1) + } + userArgs = userArgsBuffer.readOnly } @@ -92,6 +102,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --jar JAR_PATH Path to your application's JAR file | --class CLASS_NAME Name of your application's main class | --primary-py-file A main Python file + | --primary-r-file A main R file | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7219852c0a752..c1effd3c8a718 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -491,6 +491,12 @@ private[spark] class Client( } else { Nil } + val primaryRFile = + if (args.primaryRFile != null) { + Seq("--primary-r-file", args.primaryRFile) + } else { + Nil + } val amClass = if (isClusterMode) { Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName @@ -500,12 +506,15 @@ private[spark] class Client( if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs + } val userArgs = args.userArgs.flatMap { arg => Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ userArgs ++ - Seq( + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, "--num-executors ", args.numExecutors.toString) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 3bc7eb1abf341..da6798cb1b279 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -32,6 +32,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var userClass: String = null var pyFiles: String = null var primaryPyFile: String = null + var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var executorMemory = 1024 // MB var executorCores = 1 @@ -150,6 +151,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--args" | "--arg") :: value :: tail => if (args(0) == "--args") { println("--args is deprecated. Use --arg instead.") @@ -228,6 +233,11 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException(getUsageMessage(args)) } } + + if (primaryPyFile != null && primaryRFile != null) { + throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + + " at the same time") + } } private def getUsageMessage(unknownParam: List[String] = null): String = { @@ -240,6 +250,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | mode) | --class CLASS_NAME Name of your application's main class (required) | --primary-py-file A main Python file + | --primary-r-file A main R file | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index aab41fa49430f..6b8a5dbf6373e 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.apache.hadoop=WARN