diff --git a/.rat-excludes b/.rat-excludes index 769defbac11b7..8c61e67a0c7d1 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -1,4 +1,5 @@ target +cache .gitignore .gitattributes .project @@ -18,6 +19,7 @@ fairscheduler.xml.template spark-defaults.conf.template log4j.properties log4j.properties.template +metrics.properties metrics.properties.template slaves slaves.template diff --git a/README.md b/README.md index 16628bd406775..af02339578195 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ To build Spark and its example programs, run: (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at -["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-spark.html). +["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). ## Interactive Scala Shell diff --git a/assembly/pom.xml b/assembly/pom.xml index b2a9d0780ee2b..3d1ed0dd8a7bd 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -36,19 +36,9 @@ scala-${scala.binary.version} spark-assembly-${project.version}-hadoop${hadoop.version}.jar ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} - spark - /usr/share/spark - root - 744 - - - com.google.guava - guava - compile - org.apache.spark spark-core_${scala.binary.version} @@ -124,6 +114,16 @@ META-INF/*.RSA + + + org.jblas:jblas + + + lib/static/Linux/i386/** + lib/static/Mac OS X/** + lib/static/Windows/** + + @@ -133,20 +133,6 @@ shade - - - com.google - org.spark-project.guava - - com.google.common.** - - - com/google/common/base/Absent* - com/google/common/base/Optional* - com/google/common/base/Present* - - - @@ -237,113 +223,6 @@ - - deb - - - - org.codehaus.mojo - buildnumber-maven-plugin - 1.2 - - - validate - - create - - - 8 - - - - - - org.vafer - jdeb - 0.11 - - - package - - jdeb - - - ${project.build.directory}/${deb.pkg.name}_${project.version}-${buildNumber}_all.deb - false - gzip - - - ${spark.jar} - file - - 
perm - ${deb.user} - ${deb.user} - ${deb.install.path}/jars - - - - ${basedir}/src/deb/RELEASE - file - - perm - ${deb.user} - ${deb.user} - ${deb.install.path} - - - - ${basedir}/../conf - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/conf - 744 - - - - ${basedir}/../bin - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/bin - ${deb.bin.filemode} - - - - ${basedir}/../sbin - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/sbin - 744 - - - - ${basedir}/../python - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/python - 744 - - - - - - - - - - kinesis-asl diff --git a/assembly/src/deb/RELEASE b/assembly/src/deb/RELEASE deleted file mode 100644 index aad50ee73aa45..0000000000000 --- a/assembly/src/deb/RELEASE +++ /dev/null @@ -1,2 +0,0 @@ -compute-classpath.sh uses the existence of this file to decide whether to put the assembly jar on the -classpath or instead to use classfiles in the source tree. \ No newline at end of file diff --git a/assembly/src/deb/control/control b/assembly/src/deb/control/control deleted file mode 100644 index a6b4471d485f4..0000000000000 --- a/assembly/src/deb/control/control +++ /dev/null @@ -1,8 +0,0 @@ -Package: [[deb.pkg.name]] -Version: [[version]]-[[buildNumber]] -Section: misc -Priority: extra -Architecture: all -Maintainer: Matei Zaharia -Description: [[name]] -Distribution: development diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 8f3b396ffd086..f4f6b7b909490 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -50,8 +50,8 @@ fi if [ -n "$SPARK_PREPEND_CLASSES" ]; then echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ "classes ahead of assembly." 
>&2 + # Spark classes CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes" @@ -63,6 +63,8 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" + # Jars for shaded deps in their original form (copied here during build) + CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" fi # Use spark-assembly jar from either RELEASE or assembly directory @@ -72,22 +74,25 @@ else assembly_folder="$ASSEMBLY_DIR" fi -num_jars="$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar$" | wc -l)" -if [ "$num_jars" -eq "0" ]; then - echo "Failed to find Spark assembly in $assembly_folder" - echo "You need to build Spark before running this program." - exit 1 -fi +num_jars=0 + +for f in "${assembly_folder}"/spark-assembly*hadoop*.jar; do + if [[ ! -e "$f" ]]; then + echo "Failed to find Spark assembly in $assembly_folder" 1>&2 + echo "You need to build Spark before running this program." 1>&2 + exit 1 + fi + ASSEMBLY_JAR="$f" + num_jars=$((num_jars+1)) +done + if [ "$num_jars" -gt "1" ]; then - jars_list=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*.jar$") - echo "Found multiple Spark assembly jars in $assembly_folder:" - echo "$jars_list" - echo "Please remove all but one jar." + echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2 + ls "${assembly_folder}"/spark-assembly*hadoop*.jar 1>&2 + echo "Please remove all but one jar." 
1>&2 exit 1 fi -ASSEMBLY_JAR="$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null)" - # Verify that versions of java used to build the jars and run Spark are compatible jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1) if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then diff --git a/bin/run-example b/bin/run-example index 3d932509426fc..a106411392e06 100755 --- a/bin/run-example +++ b/bin/run-example @@ -35,17 +35,32 @@ else fi if [ -f "$FWDIR/RELEASE" ]; then - export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`" -elif [ -e "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar ]; then - export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar`" + JAR_PATH="${FWDIR}/lib" +else + JAR_PATH="${EXAMPLES_DIR}/target/scala-${SPARK_SCALA_VERSION}" fi -if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then - echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 - echo "You need to build Spark before running this program" 1>&2 +JAR_COUNT=0 + +for f in "${JAR_PATH}"/spark-examples-*hadoop*.jar; do + if [[ ! -e "$f" ]]; then + echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 + echo "You need to build Spark before running this program" 1>&2 + exit 1 + fi + SPARK_EXAMPLES_JAR="$f" + JAR_COUNT=$((JAR_COUNT+1)) +done + +if [ "$JAR_COUNT" -gt "1" ]; then + echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2 + ls "${JAR_PATH}"/spark-examples-*hadoop*.jar 1>&2 + echo "Please remove all but one jar." 1>&2 exit 1 fi +export SPARK_EXAMPLES_JAR + EXAMPLE_MASTER=${MASTER:-"local[*]"} if [[ ! 
$EXAMPLE_CLASS == org.apache.spark.examples* ]]; then diff --git a/bin/spark-class b/bin/spark-class index 0d58d95c1aee3..2f0441bb3c1c2 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -29,6 +29,7 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" . "$FWDIR"/bin/load-spark-env.sh @@ -71,6 +72,8 @@ case "$1" in 'org.apache.spark.executor.MesosExecutorBackend') OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} + export PYTHONPATH="$FWDIR/python:$PYTHONPATH" + export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" ;; # Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS + @@ -118,8 +121,8 @@ fi JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`" +if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then + JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`" fi # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! @@ -148,7 +151,7 @@ fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then if test -z "$SPARK_TOOLS_JAR"; then echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2 - echo "You need to build Spark before running $1." 1>&2 + echo "You need to run \"build/sbt tools/package\" before running $1." 
1>&2 exit 1 fi CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd old mode 100755 new mode 100644 diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index 12244a9cb04fb..446cbc74b74f9 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -25,7 +25,7 @@ set ORIG_ARGS=%* rem Reset the values of all variables used set SPARK_SUBMIT_DEPLOY_MODE=client -if not defined %SPARK_CONF_DIR% ( +if [%SPARK_CONF_DIR%] == [] ( set SPARK_CONF_DIR=%SPARK_HOME%\conf ) set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf diff --git a/bin/utils.sh b/bin/utils.sh index 22ea2b9a6d586..748dbe345a74c 100755 --- a/bin/utils.sh +++ b/bin/utils.sh @@ -26,16 +26,17 @@ function gatherSparkSubmitOpts() { exit 1 fi - # NOTE: If you add or remove spark-sumbmit options, + # NOTE: If you add or remove spark-submit options, # modify NOT ONLY this script but also SparkSubmitArgument.scala SUBMISSION_OPTS=() APPLICATION_OPTS=() while (($#)); do case "$1" in - --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \ - --conf | --properties-file | --driver-memory | --driver-java-options | \ + --master | --deploy-mode | --class | --name | --jars | --packages | --py-files | --files | \ + --conf | --repositories | --properties-file | --driver-memory | --driver-java-options | \ --driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \ - --total-executor-cores | --executor-cores | --queue | --num-executors | --archives) + --total-executor-cores | --executor-cores | --queue | --num-executors | --archives | \ + --proxy-user) if [[ $# -lt 2 ]]; then "$SUBMIT_USAGE_FUNCTION" exit 1; diff --git a/bin/windows-utils.cmd b/bin/windows-utils.cmd index 1082a952dac99..0cf9e87ca554b 100644 --- a/bin/windows-utils.cmd +++ b/bin/windows-utils.cmd @@ -32,7 +32,8 @@ SET opts="\<--master\> \<--deploy-mode\> \<--class\> \<--name\> \<--jars\> \<--p SET opts="%opts:~1,-1% \<--conf\> 
\<--properties-file\> \<--driver-memory\> \<--driver-java-options\>" SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>" SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>" -SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\>" +SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\> \<--packages\> \<--repositories\>" +SET opts="%opts:~1,-1% \<--proxy-user\>" echo %1 | findstr %opts% >nul if %ERRORLEVEL% equ 0 ( diff --git a/build/mvn b/build/mvn index 43471f83e904c..3561110a4c019 100755 --- a/build/mvn +++ b/build/mvn @@ -21,6 +21,8 @@ _DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" # Preserve the calling directory _CALLING_DIR="$(pwd)" +# Options used during compilation +_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" # Installs any application tarball given a URL, the expected tarball name, # and, optionally, a checkable binary path to determine if the binary has @@ -34,25 +36,25 @@ install_app() { local binary="${_DIR}/$3" # setup `curl` and `wget` silent options if we're running on Jenkins - local curl_opts="" + local curl_opts="-L" local wget_opts="" if [ -n "$AMPLAB_JENKINS" ]; then - curl_opts="-s" - wget_opts="--quiet" + curl_opts="-s ${curl_opts}" + wget_opts="--quiet ${wget_opts}" else - curl_opts="--progress-bar" - wget_opts="--progress=bar:force" + curl_opts="--progress-bar ${curl_opts}" + wget_opts="--progress=bar:force ${wget_opts}" fi if [ -z "$3" -o ! -f "$binary" ]; then # check if we already have the tarball # check if we have curl installed # download application - [ ! -f "${local_tarball}" ] && [ -n "`which curl 2>/dev/null`" ] && \ + [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ echo "exec: curl ${curl_opts} ${remote_tarball}" && \ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" # if the file still doesn't exist, lets try `wget` and cross our fingers - [ ! 
-f "${local_tarball}" ] && [ -n "`which wget 2>/dev/null`" ] && \ + [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ echo "exec: wget ${wget_opts} ${remote_tarball}" && \ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" # if both were unsuccessful, exit @@ -68,10 +70,10 @@ install_app() { # Install maven under the build/ folder install_mvn() { install_app \ - "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \ - "apache-maven-3.2.3-bin.tar.gz" \ - "apache-maven-3.2.3/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ + "apache-maven-3.2.5-bin.tar.gz" \ + "apache-maven-3.2.5/bin/mvn" + MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" } # Install zinc under the build/ folder @@ -136,6 +138,7 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then + export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} ${ZINC_BIN} -shutdown ${ZINC_BIN} -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ @@ -143,7 +146,7 @@ if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then fi # Set any `mvn` options if not already present -export MAVEN_OPTS=${MAVEN_OPTS:-"-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"} +export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} # Last, call the `mvn` command as usual ${MVN_BIN} "$@" diff --git a/build/sbt b/build/sbt index 28ebb64f7197c..cc3203d79bccd 100755 --- a/build/sbt +++ b/build/sbt @@ -125,4 +125,32 @@ loadConfigFile() { [[ -f "$etc_sbt_opts_file" ]] && set -- $(loadConfigFile "$etc_sbt_opts_file") "$@" [[ -f "$sbt_opts_file" ]] && set -- $(loadConfigFile "$sbt_opts_file") "$@" +exit_status=127 +saved_stty="" + +restoreSttySettings() { + stty $saved_stty + saved_stty="" +} + +onExit() { + if [[ "$saved_stty" != "" ]]; then + restoreSttySettings + fi + exit 
$exit_status +} + +saveSttySettings() { + saved_stty=$(stty -g 2>/dev/null) + if [[ ! $? ]]; then + saved_stty="" + fi +} + +saveSttySettings +trap onExit INT + run "$@" + +exit_status=$? +onExit diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index f5df439effb01..504be48b358fa 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -50,9 +50,9 @@ acquire_sbt_jar () { # Download printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" - if hash curl 2>/dev/null; then + if [ $(command -v curl) ]; then (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" - elif hash wget 2>/dev/null; then + elif [ $(command -v wget) ]; then (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" @@ -81,7 +81,7 @@ execRunner () { echo "" } - exec "$@" + "$@" } addJava () { diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 96b6844f0aabb..2e0cb5db170ac 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -87,6 +87,7 @@ # period 10 Poll period # unit seconds Units of poll period # prefix EMPTY STRING Prefix to prepend to metric name +# protocol tcp Protocol ("tcp" or "udp") to use ## Examples # Enable JmxSink for all instances by class name @@ -121,6 +122,15 @@ #worker.sink.csv.unit=minutes +# Enable Slf4jSink for all instances by class name +#*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink + +# Polling period for Slf4jSink +#*.sink.slf4j.period=1 + +#*.sink.slf4j.unit=minutes + + # Enable jvm source for instance master, worker, driver and executor #master.source.jvm.class=org.apache.spark.metrics.source.JvmSource diff --git a/core/pom.xml b/core/pom.xml index d9a49c9e08afc..c993781c0e0d6 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -34,6 +34,10 @@ Spark
Project Core http://spark.apache.org/ + + com.google.guava + guava + com.twitter chill_${scala.binary.version} @@ -90,32 +94,52 @@ org.apache.curator curator-recipes + + org.eclipse.jetty jetty-plus + compile org.eclipse.jetty jetty-security + compile org.eclipse.jetty jetty-util + compile org.eclipse.jetty jetty-server + compile - - com.google.guava - guava + org.eclipse.jetty + jetty-http compile + + org.eclipse.jetty + jetty-continuation + compile + + + org.eclipse.jetty + jetty-servlet + compile + + + + org.eclipse.jetty.orbit + javax.servlet + ${orbit.version} + + org.apache.commons commons-lang3 @@ -204,26 +228,45 @@ stream - com.codahale.metrics + io.dropwizard.metrics metrics-core - com.codahale.metrics + io.dropwizard.metrics metrics-jvm - com.codahale.metrics + io.dropwizard.metrics metrics-json - com.codahale.metrics + io.dropwizard.metrics metrics-graphite + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.module + jackson-module-scala_2.10 + org.apache.derby derby test + + org.apache.ivy + ivy + ${ivy.version} + + + oro + + oro + ${oro.version} + org.tachyonproject tachyon-client @@ -286,16 +329,6 @@ scalacheck_${scala.binary.version} test - - org.easymock - easymockclassextension - test - - - asm - asm - test - junit junit @@ -350,59 +383,28 @@ true - - org.apache.maven.plugins - maven-shade-plugin - - - package - - shade - - - false - - - com.google.guava:guava - - - - - - com.google.guava:guava - - com/google/common/base/Absent* - com/google/common/base/Optional* - com/google/common/base/Present* - - - - - - - - org.apache.maven.plugins maven-dependency-plugin + copy-dependencies package copy-dependencies - + ${project.build.directory} false false true true - guava + + guava,jetty-io,jetty-servlet,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server + true diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java new file mode 
100644 index 0000000000000..fbc5666959055 --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.scheduler.*; + +/** + * Class that allows users to receive all SparkListener events. + * Users should override the onEvent method. + * + * This is a concrete Java class in order to ensure that we don't forget to update it when adding + * new methods to SparkListener: forgetting to add a method will result in a compilation error (if + * this was a concrete Scala class, default implementations of new event handlers would be inherited + * from the SparkListener trait). 
+ */ +public class SparkFirehoseListener implements SparkListener { + + public void onEvent(SparkListenerEvent event) { } + + @Override + public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) { + onEvent(stageCompleted); + } + + @Override + public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { + onEvent(stageSubmitted); + } + + @Override + public final void onTaskStart(SparkListenerTaskStart taskStart) { + onEvent(taskStart); + } + + @Override + public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { + onEvent(taskGettingResult); + } + + @Override + public final void onTaskEnd(SparkListenerTaskEnd taskEnd) { + onEvent(taskEnd); + } + + @Override + public final void onJobStart(SparkListenerJobStart jobStart) { + onEvent(jobStart); + } + + @Override + public final void onJobEnd(SparkListenerJobEnd jobEnd) { + onEvent(jobEnd); + } + + @Override + public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { + onEvent(environmentUpdate); + } + + @Override + public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { + onEvent(blockManagerAdded); + } + + @Override + public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { + onEvent(blockManagerRemoved); + } + + @Override + public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { + onEvent(unpersistRDD); + } + + @Override + public final void onApplicationStart(SparkListenerApplicationStart applicationStart) { + onEvent(applicationStart); + } + + @Override + public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { + onEvent(applicationEnd); + } + + @Override + public final void onExecutorMetricsUpdate( + SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { + onEvent(executorMetricsUpdate); + } + + @Override + public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { + 
onEvent(executorAdded); + } + + @Override + public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { + onEvent(executorRemoved); + } +} diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java deleted file mode 100644 index 095f9fb94fdf0..0000000000000 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark; - -import java.io.Serializable; - -import scala.Function0; -import scala.Function1; -import scala.Unit; - -import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.executor.TaskMetrics; -import org.apache.spark.util.TaskCompletionListener; - -/** - * Contextual information about a task which can be read or mutated during - * execution. To access the TaskContext for a running task use - * TaskContext.get(). - */ -public abstract class TaskContext implements Serializable { - /** - * Return the currently active TaskContext. This can be called inside of - * user functions to access contextual information about running tasks. 
- */ - public static TaskContext get() { - return taskContext.get(); - } - - private static ThreadLocal taskContext = - new ThreadLocal(); - - static void setTaskContext(TaskContext tc) { - taskContext.set(tc); - } - - static void unset() { - taskContext.remove(); - } - - /** - * Whether the task has completed. - */ - public abstract boolean isCompleted(); - - /** - * Whether the task has been killed. - */ - public abstract boolean isInterrupted(); - - /** @deprecated use {@link #isRunningLocally()} */ - @Deprecated - public abstract boolean runningLocally(); - - public abstract boolean isRunningLocally(); - - /** - * Add a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener); - - /** - * Add a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situations - success, failure, or cancellation. - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - public abstract TaskContext addTaskCompletionListener(final Function1 f); - - /** - * Add a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * - * @deprecated use {@link #addTaskCompletionListener(scala.Function1)} - * - * @param f Callback function. - */ - @Deprecated - public abstract void addOnCompleteCallback(final Function0 f); - - /** - * The ID of the stage that this task belong to. - */ - public abstract int stageId(); - - /** - * The ID of the RDD partition that is computed by this task. - */ - public abstract int partitionId(); - - /** - * How many times this task has been attempted. 
The first task attempt will be assigned - * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers. - */ - public abstract int attemptNumber(); - - /** @deprecated use {@link #taskAttemptId()}; it was renamed to avoid ambiguity. */ - @Deprecated - public abstract long attemptId(); - - /** - * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts - * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. - */ - public abstract long taskAttemptId(); - - /** ::DeveloperApi:: */ - @DeveloperApi - public abstract TaskMetrics taskMetrics(); -} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 5751964b792ce..6c37cc8b98236 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -19,6 +19,7 @@ height: 50px; font-size: 15px; margin-bottom: 15px; + min-width: 1200px } .navbar .navbar-inner { @@ -39,12 +40,12 @@ .navbar .nav > li a { height: 30px; - line-height: 30px; + line-height: 2; } .navbar-text { height: 50px; - line-height: 50px; + line-height: 3.3; } table.sortable thead { @@ -102,6 +103,12 @@ span.expand-details { float: right; } +span.rest-uri { + font-size: 10pt; + font-style: italic; + color: gray; +} + pre { font-size: 0.8em; } @@ -120,6 +127,14 @@ pre { border: none; } +.description-input { + overflow: hidden; + text-overflow: ellipsis; + width: 100%; + white-space: nowrap; + display: block; +} + .stacktrace-details { max-height: 300px; overflow-y: auto; @@ -170,7 +185,7 @@ span.additional-metric-title { } .version { - line-height: 30px; + line-height: 2.5; vertical-align: bottom; font-size: 12px; padding: 0; @@ -181,6 +196,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. 
This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ -.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time { +.scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, +.serialization_time, .getting_result_time { display: none; } diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5f31bfba3f8d6..30f0ccd73ccca 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -23,6 +23,7 @@ import java.lang.ThreadLocal import scala.collection.generic.Growable import scala.collection.mutable.Map +import scala.ref.WeakReference import scala.reflect.ClassTag import org.apache.spark.serializer.JavaSerializer @@ -280,10 +281,12 @@ object AccumulatorParam { // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right private[spark] object Accumulators { - // TODO: Use soft references? 
=> need to make readObject work properly then - val originals = Map[Long, Accumulable[_, _]]() - val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { - override protected def initialValue() = Map[Long, Accumulable[_, _]]() + // Store a WeakReference instead of a StrongReference because this way accumulators can be + // appropriately garbage collected during long-running jobs and release memory + type WeakAcc = WeakReference[Accumulable[_, _]] + val originals = Map[Long, WeakAcc]() + val localAccums = new ThreadLocal[Map[Long, WeakAcc]]() { + override protected def initialValue() = Map[Long, WeakAcc]() } var lastId: Long = 0 @@ -294,9 +297,9 @@ private[spark] object Accumulators { def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { if (original) { - originals(a.id) = a + originals(a.id) = new WeakAcc(a) } else { - localAccums.get()(a.id) = a + localAccums.get()(a.id) = new WeakAcc(a) } } @@ -307,11 +310,22 @@ private[spark] object Accumulators { } } + def remove(accId: Long) { + synchronized { + originals.remove(accId) + } + } + // Get the values of the local accumulators for the current thread (by ID) def values: Map[Long, Any] = synchronized { val ret = Map[Long, Any]() for ((id, accum) <- localAccums.get) { - ret(id) = accum.localValue + // Since we are now storing weak references, we must check whether the underlying data + // is valid. + ret(id) = accum.get match { + case Some(values) => values.localValue + case None => None + } } return ret } @@ -320,7 +334,13 @@ private[spark] object Accumulators { def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { if (originals.contains(id)) { - originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value + // Since we are now storing weak references, we must check whether the underlying data + // is valid. 
+ originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value + case None => + throw new IllegalAccessError("Attempted to access garbage collected Accumulator.") + } } } } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 09eb9605fb799..3b684bbeceaf2 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -61,8 +61,8 @@ case class Aggregator[K, V, C] ( // Update task metrics if context is not null // TODO: Make context non optional in a future release Option(context).foreach { c => - c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled - c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled + c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) + c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) } combiners.iterator } @@ -95,8 +95,8 @@ case class Aggregator[K, V, C] ( // Update task metrics if context is not null // TODO: Make context non-optional in a future release Option(context).foreach { c => - c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled - c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled + c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) + c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) } combiners.iterator } diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index a0c0372b7f0ef..a96d754744a05 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -47,10 +47,15 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { val inputMetrics = blockResult.inputMetrics val existingMetrics = context.taskMetrics .getInputMetricsForReadMethod(inputMetrics.readMethod) - 
existingMetrics.addBytesRead(inputMetrics.bytesRead) - - new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) + existingMetrics.incBytesRead(inputMetrics.bytesRead) + val iter = blockResult.data.asInstanceOf[Iterator[T]] + new InterruptibleIterator[T](context, iter) { + override def next(): T = { + existingMetrics.incRecordsRead(1) + delegate.next() + } + } case None => // Acquire a lock for loading this partition // If another thread already holds the lock, wait for it to finish return its results diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index ede1e23f4fcc5..434f1e47cf822 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -32,6 +32,7 @@ private sealed trait CleanupTask private case class CleanRDD(rddId: Int) extends CleanupTask private case class CleanShuffle(shuffleId: Int) extends CleanupTask private case class CleanBroadcast(broadcastId: Long) extends CleanupTask +private case class CleanAccum(accId: Long) extends CleanupTask /** * A WeakReference associated with a CleanupTask. @@ -114,6 +115,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { registerForCleanup(rdd, CleanRDD(rdd.id)) } + def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = { + registerForCleanup(a, CleanAccum(a.id)) + } + /** Register a ShuffleDependency for cleanup when it is garbage collected. 
*/ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) @@ -145,6 +150,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + case CleanAccum(accId) => + doCleanupAccum(accId, blocking = blockOnCleanupTasks) } } } catch { @@ -190,6 +197,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } + /** Perform accumulator cleanup. */ + def doCleanupAccum(accId: Long, blocking: Boolean) { + try { + logDebug("Cleaning accumulator " + accId) + Accumulators.remove(accId) + listeners.foreach(_.accumCleaned(accId)) + logInfo("Cleaned accumulator " + accId) + } catch { + case e: Exception => logError("Error cleaning accumulator " + accId, e) + } + } + private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] @@ -206,4 +225,5 @@ private[spark] trait CleanerListener { def rddCleaned(rddId: Int) def shuffleCleaned(shuffleId: Int) def broadcastCleaned(broadcastId: Long) + def accumCleaned(accId: Long) } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index a46a81eabd965..443830f8d03b6 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -19,24 +19,32 @@ package org.apache.spark /** * A client that communicates with the cluster manager to request or kill executors. + * This is currently supported only in YARN mode. 
*/ private[spark] trait ExecutorAllocationClient { + /** + * Express a preference to the cluster manager for a given total number of executors. + * This can result in canceling pending requests or filing additional requests. + * @return whether the request is acknowledged by the cluster manager. + */ + private[spark] def requestTotalExecutors(numExecutors: Int): Boolean + /** * Request an additional number of executors from the cluster manager. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def requestExecutors(numAdditionalExecutors: Int): Boolean /** * Request that the cluster manager kill the specified executors. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def killExecutors(executorIds: Seq[String]): Boolean /** * Request that the cluster manager kill the specified executor. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId)) } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index a0ee2a7cbb2a2..21c6e6ffa6666 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.collection.mutable import org.apache.spark.scheduler._ +import org.apache.spark.util.{SystemClock, Clock} /** * An agent that dynamically allocates and removes executors based on the workload. 
@@ -49,6 +50,7 @@ import org.apache.spark.scheduler._ * spark.dynamicAllocation.enabled - Whether this feature is enabled * spark.dynamicAllocation.minExecutors - Lower bound on the number of executors * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors + * spark.dynamicAllocation.initialExecutors - Number of executors to start with * * spark.dynamicAllocation.schedulerBacklogTimeout (M) - * If there are backlogged tasks for this duration, add new executors @@ -70,19 +72,20 @@ private[spark] class ExecutorAllocationManager( import ExecutorAllocationManager._ - // Lower and upper bounds on the number of executors. These are required. - private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1) - private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1) + // Lower and upper bounds on the number of executors. + private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) + private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", + Integer.MAX_VALUE) - // How long there must be backlogged tasks for before an addition is triggered + // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.schedulerBacklogTimeout", 60) + "spark.dynamicAllocation.schedulerBacklogTimeout", 5) // Same as above, but used only after `schedulerBacklogTimeout` is exceeded private val sustainedSchedulerBacklogTimeout = conf.getLong( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) - // How long an executor must be idle for before it is removed + // How long an executor must be idle for before it is removed (seconds) private val executorIdleTimeout = conf.getLong( "spark.dynamicAllocation.executorIdleTimeout", 600) @@ -121,7 +124,7 @@ private[spark] class ExecutorAllocationManager( private val intervalMillis: Long = 100 // 
Clock used to schedule when executors should be added and removed - private var clock: Clock = new RealClock + private var clock: Clock = new SystemClock() // Listener for Spark events that impact the allocation policy private val listener = new ExecutorAllocationListener @@ -132,10 +135,10 @@ private[spark] class ExecutorAllocationManager( */ private def validateSettings(): Unit = { if (minNumExecutors < 0 || maxNumExecutors < 0) { - throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!") + throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be positive!") } - if (minNumExecutors == 0 || maxNumExecutors == 0) { - throw new SparkException("spark.dynamicAllocation.{min/max}Executors cannot be 0!") + if (maxNumExecutors == 0) { + throw new SparkException("spark.dynamicAllocation.maxExecutors cannot be 0!") } if (minNumExecutors > maxNumExecutors) { throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + @@ -158,7 +161,7 @@ private[spark] class ExecutorAllocationManager( "shuffle service. You may enable this through spark.shuffle.service.enabled.") } if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores") + throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") } } @@ -199,18 +202,34 @@ private[spark] class ExecutorAllocationManager( } /** - * If the add time has expired, request new executors and refresh the add time. - * If the remove time for an existing executor has expired, kill the executor. + * The number of executors we would have if the cluster manager were to fulfill all our existing + * requests. + */ + private def targetNumExecutors(): Int = + numExecutorsPending + executorIds.size - executorsPendingToRemove.size + + /** + * The maximum number of executors we would need under the current load to satisfy all running + * and pending tasks, rounded up. 
+ */ + private def maxNumExecutorsNeeded(): Int = { + val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks + (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + } + + /** + * This is called at a fixed interval to regulate the number of pending executor requests + * and number of executors running. + * + * First, adjust our requested executors based on the add time and our current needs. + * Then, if the remove time for an existing executor has expired, kill the executor. + * * This is factored out into its own method for testing. */ private def schedule(): Unit = synchronized { val now = clock.getTimeMillis - if (addTime != NOT_SET && now >= addTime) { - addExecutors() - logDebug(s"Starting timer to add more executors (to " + - s"expire in $sustainedSchedulerBacklogTimeout seconds)") - addTime += sustainedSchedulerBacklogTimeout * 1000 - } + + addOrCancelExecutorRequests(now) removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime @@ -221,59 +240,89 @@ private[spark] class ExecutorAllocationManager( } } + /** + * Check to see whether our existing allocation and the requests we've made previously exceed our + * current needs. If so, let the cluster manager know so that it can cancel pending requests that + * are unneeded. + * + * If not, and the add time has expired, see if we can request new executors and refresh the add + * time. + * + * @return the delta in the target number of executors. + */ + private def addOrCancelExecutorRequests(now: Long): Int = synchronized { + val currentTarget = targetNumExecutors + val maxNeeded = maxNumExecutorsNeeded + + if (maxNeeded < currentTarget) { + // The target number exceeds the number we actually need, so stop adding new + // executors and inform the cluster manager to cancel the extra pending requests. 
+ val newTotalExecutors = math.max(maxNeeded, minNumExecutors) + client.requestTotalExecutors(newTotalExecutors) + numExecutorsToAdd = 1 + updateNumExecutorsPending(newTotalExecutors) + } else if (addTime != NOT_SET && now >= addTime) { + val delta = addExecutors(maxNeeded) + logDebug(s"Starting timer to add more executors (to " + + s"expire in $sustainedSchedulerBacklogTimeout seconds)") + addTime += sustainedSchedulerBacklogTimeout * 1000 + delta + } else { + 0 + } + } + /** * Request a number of executors from the cluster manager. * If the cap on the number of executors is reached, give up and reset the * number of executors to add next round instead of continuing to double it. - * Return the number actually requested. + * + * @param maxNumExecutorsNeeded the maximum number of executors all currently running or pending + * tasks could fill + * @return the number of additional executors actually requested. */ - private def addExecutors(): Int = synchronized { - // Do not request more executors if we have already reached the upper bound - val numExistingExecutors = executorIds.size + numExecutorsPending - if (numExistingExecutors >= maxNumExecutors) { + private def addExecutors(maxNumExecutorsNeeded: Int): Int = { + // Do not request more executors if it would put our target over the upper bound + val currentTarget = targetNumExecutors + if (currentTarget >= maxNumExecutors) { logDebug(s"Not adding executors because there are already ${executorIds.size} " + s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } - // The number of executors needed to satisfy all pending tasks is the number of tasks pending - // divided by the number of tasks each executor can fit, rounded up. 
- val maxNumExecutorsPending = - (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor - if (numExecutorsPending >= maxNumExecutorsPending) { - logDebug(s"Not adding executors because there are already $numExecutorsPending " + - s"pending and pending tasks could only fill $maxNumExecutorsPending") - numExecutorsToAdd = 1 - return 0 - } - - // It's never useful to request more executors than could satisfy all the pending tasks, so - // cap request at that amount. - // Also cap request with respect to the configured upper bound. - val maxNumExecutorsToAdd = math.min( - maxNumExecutorsPending - numExecutorsPending, - maxNumExecutors - numExistingExecutors) - assert(maxNumExecutorsToAdd > 0) - - val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd) - - val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd - val addRequestAcknowledged = testing || client.requestExecutors(actualNumExecutorsToAdd) + val actualMaxNumExecutors = math.min(maxNumExecutors, maxNumExecutorsNeeded) + val newTotalExecutors = math.min(currentTarget + numExecutorsToAdd, actualMaxNumExecutors) + val addRequestAcknowledged = testing || client.requestTotalExecutors(newTotalExecutors) if (addRequestAcknowledged) { - logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " + - s"tasks are backlogged (new desired total will be $newTotalExecutors)") - numExecutorsToAdd = - if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1 - numExecutorsPending += actualNumExecutorsToAdd - actualNumExecutorsToAdd + val delta = updateNumExecutorsPending(newTotalExecutors) + logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" + + s" (new desired total will be $newTotalExecutors)") + numExecutorsToAdd = if (delta == numExecutorsToAdd) { + numExecutorsToAdd * 2 + } else { + 1 + } + delta } else { - logWarning(s"Unable to reach the cluster manager " + - s"to request $actualNumExecutorsToAdd 
executors!") + logWarning( + s"Unable to reach the cluster manager to request $newTotalExecutors total executors!") 0 } } + /** + * Given the new target number of executors, update the number of pending executor requests, + * and return the delta from the old number of pending requests. + */ + private def updateNumExecutorsPending(newTotalExecutors: Int): Int = { + val newNumExecutorsPending = + newTotalExecutors - executorIds.size + executorsPendingToRemove.size + val delta = newNumExecutorsPending - numExecutorsPending + numExecutorsPending = newNumExecutorsPending + delta + } + /** * Request the cluster manager to remove the given executor. * Return whether the request is received. @@ -413,6 +462,8 @@ private[spark] class ExecutorAllocationManager( private val stageIdToNumTasks = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] + // Number of tasks currently running on the cluster. Should be 0 when no stages are active. + private var numRunningTasks: Int = _ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { val stageId = stageSubmitted.stageInfo.stageId @@ -433,6 +484,10 @@ private[spark] class ExecutorAllocationManager( // This is needed in case the stage is aborted for any reason if (stageIdToNumTasks.isEmpty) { allocationManager.onSchedulerQueueEmpty() + if (numRunningTasks != 0) { + logWarning("No stages are running, but numRunningTasks != 0") + numRunningTasks = 0 + } } } } @@ -444,6 +499,7 @@ private[spark] class ExecutorAllocationManager( val executorId = taskStart.taskInfo.executorId allocationManager.synchronized { + numRunningTasks += 1 // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is // possible because these events are posted in different threads. 
(see SPARK-4951) @@ -473,7 +529,8 @@ private[spark] class ExecutorAllocationManager( val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId allocationManager.synchronized { - // If the executor is no longer running scheduled any tasks, mark it as idle + numRunningTasks -= 1 + // If the executor is no longer running any scheduled tasks, mark it as idle if (executorIdToTaskIds.contains(executorId)) { executorIdToTaskIds(executorId) -= taskId if (executorIdToTaskIds(executorId).isEmpty) { @@ -484,8 +541,8 @@ private[spark] class ExecutorAllocationManager( } } - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { - val executorId = blockManagerAdded.blockManagerId.executorId + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + val executorId = executorAdded.executorId if (executorId != SparkContext.DRIVER_IDENTIFIER) { // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is @@ -496,9 +553,8 @@ private[spark] class ExecutorAllocationManager( } } - override def onBlockManagerRemoved( - blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { - allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId) + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + allocationManager.onExecutorRemoved(executorRemoved.executorId) } /** @@ -513,6 +569,11 @@ private[spark] class ExecutorAllocationManager( }.sum } + /** + * The number of tasks currently running across all stages. + */ + def totalRunningTasks(): Int = numRunningTasks + /** * Return true if an executor is not currently running a task, and false otherwise. * @@ -528,28 +589,3 @@ private[spark] class ExecutorAllocationManager( private object ExecutorAllocationManager { val NOT_SET = Long.MaxValue } - -/** - * An abstract clock for measuring elapsed time. 
- */ -private trait Clock { - def getTimeMillis: Long -} - -/** - * A clock backed by a monotonically increasing time source. - * The time returned by this clock does not correspond to any notion of wall-clock time. - */ -private class RealClock extends Clock { - override def getTimeMillis: Long = System.nanoTime / (1000 * 1000) -} - -/** - * A clock that allows the caller to customize the time. - * This is used mainly for testing. - */ -private class TestClock(startTimeMillis: Long) extends Clock { - private var time: Long = startTimeMillis - override def getTimeMillis: Long = time - def tick(ms: Long): Unit = { time += ms } -} diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 677c5e0f89d72..7e706bcc42f04 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -36,7 +36,7 @@ private[spark] class HttpFileServer( var serverUri : String = null def initialize() { - baseDir = Utils.createTempDir() + baseDir = Utils.createTempDir(Utils.getLocalDir(conf), "httpd") fileDir = new File(baseDir, "files") jarDir = new File(baseDir, "jars") fileDir.mkdir() @@ -50,6 +50,15 @@ private[spark] class HttpFileServer( def stop() { httpServer.stop() + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs + try { + Utils.deleteRecursively(baseDir) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: ${baseDir.getAbsolutePath}", e) + } } def addFile(file: File) : String = { diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index fa22787ce7ea3..09a9ccc226721 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -19,6 +19,7 @@ package 
org.apache.spark import java.io.File +import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} @@ -72,7 +73,10 @@ private[spark] class HttpServer( */ private def doStart(startPort: Int): (Server, Int) = { val server = new Server() - val connector = new SocketConnector + + val connector = securityManager.fileServerSSLOptions.createJettySslContextFactory() + .map(new SslSocketConnector(_)).getOrElse(new SocketConnector) + connector.setMaxIdleTime(60 * 1000) connector.setSoLingerTime(-1) connector.setPort(startPort) @@ -149,13 +153,14 @@ private[spark] class HttpServer( } /** - * Get the URI of this HTTP server (http://host:port) + * Get the URI of this HTTP server (http://host:port or https://host:port) */ def uri: String = { if (server == null) { throw new ServerStateException("Server is not started") } else { - "http://" + Utils.localIpAddress + ":" + port + val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http" + s"$scheme://${Utils.localIpAddress}:$port" } } } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index d4f2624061e35..419d093d55643 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -118,15 +118,17 @@ trait Logging { // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently // org.apache.logging.slf4j.Log4jLoggerFactory val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) - val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4j12Initialized && usingLog4j12) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - 
Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (usingLog4j12) { + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized) { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } } Logging.initialized = true diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala new file mode 100644 index 0000000000000..2cdc167f85af0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.io.File + +import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} +import org.eclipse.jetty.util.ssl.SslContextFactory + +/** + * SSLOptions class is a common container for SSL configuration options. It offers methods to + * generate specific objects to configure SSL for different communication protocols. + * + * SSLOptions is intended to provide the maximum common set of SSL settings, which are supported + * by the protocol, which it can generate the configuration for. Since Akka doesn't support client + * authentication with SSL, SSLOptions cannot support it either. + * + * @param enabled enables or disables SSL; if it is set to false, the rest of the + * settings are disregarded + * @param keyStore a path to the key-store file + * @param keyStorePassword a password to access the key-store file + * @param keyPassword a password to access the private key in the key-store + * @param trustStore a path to the trust-store file + * @param trustStorePassword a password to access the trust-store file + * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java + * @param enabledAlgorithms a set of encryption algorithms to use + */ +private[spark] case class SSLOptions( + enabled: Boolean = false, + keyStore: Option[File] = None, + keyStorePassword: Option[String] = None, + keyPassword: Option[String] = None, + trustStore: Option[File] = None, + trustStorePassword: Option[String] = None, + protocol: Option[String] = None, + enabledAlgorithms: Set[String] = Set.empty) { + + /** + * Creates a Jetty SSL context factory according to the SSL settings represented by this object. 
+ */ + def createJettySslContextFactory(): Option[SslContextFactory] = { + if (enabled) { + val sslContextFactory = new SslContextFactory() + + keyStore.foreach(file => sslContextFactory.setKeyStorePath(file.getAbsolutePath)) + trustStore.foreach(file => sslContextFactory.setTrustStore(file.getAbsolutePath)) + keyStorePassword.foreach(sslContextFactory.setKeyStorePassword) + trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) + keyPassword.foreach(sslContextFactory.setKeyManagerPassword) + protocol.foreach(sslContextFactory.setProtocol) + sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + + Some(sslContextFactory) + } else { + None + } + } + + /** + * Creates an Akka configuration object which contains all the SSL settings represented by this + * object. It can be used then to compose the ultimate Akka configuration. + */ + def createAkkaConfig: Option[Config] = { + import scala.collection.JavaConversions._ + if (enabled) { + Some(ConfigFactory.empty() + .withValue("akka.remote.netty.tcp.security.key-store", + ConfigValueFactory.fromAnyRef(keyStore.map(_.getAbsolutePath).getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.key-store-password", + ConfigValueFactory.fromAnyRef(keyStorePassword.getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.trust-store", + ConfigValueFactory.fromAnyRef(trustStore.map(_.getAbsolutePath).getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.trust-store-password", + ConfigValueFactory.fromAnyRef(trustStorePassword.getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.key-password", + ConfigValueFactory.fromAnyRef(keyPassword.getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.random-number-generator", + ConfigValueFactory.fromAnyRef("")) + .withValue("akka.remote.netty.tcp.security.protocol", + ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.enabled-algorithms", + 
ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + .withValue("akka.remote.netty.tcp.enable-ssl", + ConfigValueFactory.fromAnyRef(true))) + } else { + None + } + } + + /** Returns a string representation of this SSLOptions with all the passwords masked. */ + override def toString: String = s"SSLOptions{enabled=$enabled, " + + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + + s"trustStore=$trustStore, trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " + + s"protocol=$protocol, enabledAlgorithms=$enabledAlgorithms}" + +} + +private[spark] object SSLOptions extends Logging { + + /** Resolves SSLOptions settings from a given Spark configuration object at a given namespace. + * + * The following settings are allowed: + * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively + * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory + * $ - `[ns].keyStorePassword` - a password to the key-store file + * $ - `[ns].keyPassword` - a password to the private key + * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current + * directory + * $ - `[ns].trustStorePassword` - a password to the trust-store file + * $ - `[ns].protocol` - a protocol name supported by a particular Java version + * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers + * + * For a list of protocols and ciphers supported by particular Java versions, you may go to + * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle + * blog page]]. + * + * You can optionally specify the default configuration. If you do, for each setting which is + * missing in SparkConf, the corresponding setting is used from the default configuration. 
+ * + * @param conf Spark configuration object where the settings are collected from + * @param ns the namespace name + * @param defaults the default configuration + * @return [[org.apache.spark.SSLOptions]] object + */ + def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { + val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) + + val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_)) + .orElse(defaults.flatMap(_.keyStore)) + + val keyStorePassword = conf.getOption(s"$ns.keyStorePassword") + .orElse(defaults.flatMap(_.keyStorePassword)) + + val keyPassword = conf.getOption(s"$ns.keyPassword") + .orElse(defaults.flatMap(_.keyPassword)) + + val trustStore = conf.getOption(s"$ns.trustStore").map(new File(_)) + .orElse(defaults.flatMap(_.trustStore)) + + val trustStorePassword = conf.getOption(s"$ns.trustStorePassword") + .orElse(defaults.flatMap(_.trustStorePassword)) + + val protocol = conf.getOption(s"$ns.protocol") + .orElse(defaults.flatMap(_.protocol)) + + val enabledAlgorithms = conf.getOption(s"$ns.enabledAlgorithms") + .map(_.split(",").map(_.trim).filter(_.nonEmpty).toSet) + .orElse(defaults.map(_.enabledAlgorithms)) + .getOrElse(Set.empty) + + new SSLOptions( + enabled, + keyStore, + keyStorePassword, + keyPassword, + trustStore, + trustStorePassword, + protocol, + enabledAlgorithms) + } + +} + diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index ec82d09cd079b..3653f724ba192 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -18,11 +18,16 @@ package org.apache.spark import java.net.{Authenticator, PasswordAuthentication} +import java.security.KeyStore +import java.security.cert.X509Certificate +import javax.net.ssl._ +import com.google.common.io.Files import org.apache.hadoop.io.Text import 
org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.network.sasl.SecretKeyHolder +import org.apache.spark.util.Utils /** * Spark class responsible for security. @@ -55,7 +60,7 @@ import org.apache.spark.network.sasl.SecretKeyHolder * Spark also has a set of admin acls (`spark.admin.acls`) which is a set of users/administrators * who always have permission to view or modify the Spark application. * - * Spark does not currently support encryption after authentication. + * Starting from version 1.3, Spark has partial support for encrypted connections with SSL. * * At this point spark has multiple communication protocols that need to be secured and * different underlying mechanisms are used depending on the protocol: @@ -67,8 +72,9 @@ import org.apache.spark.network.sasl.SecretKeyHolder * to connect to the server. There is no control of the underlying * authentication mechanism so its not clear if the password is passed in * plaintext or uses DIGEST-MD5 or some other mechanism. - * Akka also has an option to turn on SSL, this option is not currently supported - * but we could add a configuration option in the future. + * + * Akka also has an option to turn on SSL, this option is currently supported (see + * the details below). * * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty * for the HttpServer. Jetty supports multiple authentication mechanisms - @@ -77,8 +83,9 @@ import org.apache.spark.network.sasl.SecretKeyHolder * to authenticate using DIGEST-MD5 via a single user and the shared secret. * Since we are using DIGEST-MD5, the shared secret is not passed on the wire * in plaintext. - * We currently do not support SSL (https), but Jetty can be configured to use it - * so we could add a configuration option for this in the future. + * + * We currently support SSL (https) for this communication protocol (see the details + * below). * * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. 
* Any clients must specify the user and password. There is a default @@ -142,9 +149,40 @@ import org.apache.spark.network.sasl.SecretKeyHolder * authentication. Spark will then use that user to compare against the view acls to do * authorization. If not filter is in place the user is generally null and no authorization * can take place. + * + * Connection encryption (SSL) configuration is organized hierarchically. The user can configure + * the default SSL settings which will be used for all the supported communication protocols unless + * they are overwritten by protocol specific settings. This way the user can easily provide the + * common settings for all the protocols without disabling the ability to configure each one + * individually. + * + * All the SSL settings like `spark.ssl.xxx` where `xxx` is a particular configuration property, + * denote the global configuration for all the supported protocols. In order to override the global + * configuration for the particular protocol, the properties must be overwritten in the + * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global + * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be either `akka` for + * Akka based connections or `fs` for broadcast and file server. + * + * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of + * options that can be specified. + * + * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions + * object parses Spark configuration at a given namespace and builds the common representation + * of SSL settings. SSLOptions is then used to provide protocol-specific configuration like + * TypeSafe configuration for Akka or SSLContextFactory for Jetty. + * + * SSL must be configured on each node and configured for each component involved in + * communication using the particular protocol. 
In YARN clusters, the key-store can be prepared on + * the client side then distributed and used by the executors as the part of the application + * (YARN allows the user to deploy files before the application is started). + * In standalone deployment, the user needs to provide key-stores and configuration + * options for master and workers. In this mode, the user may allow the executors to use the SSL + * settings inherited from the worker which spawned that executor. It can be accomplished by + * setting `spark.ssl.useNodeLocalConf` to `true`. */ -private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { +private[spark] class SecurityManager(sparkConf: SparkConf) + extends Logging with SecretKeyHolder { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" @@ -166,7 +204,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with // always add the current user and SPARK_USER to the viewAcls private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), - Option(System.getenv("SPARK_USER")).getOrElse("")).filter(!_.isEmpty) + Utils.getCurrentUserName()) setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) @@ -196,6 +234,57 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with ) } + // the default SSL configuration - it will be used by all communication layers unless overwritten + private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) + + // SSL configuration for different communication layers - they can override the default + // configuration at a specified namespace. The namespace *must* start with spark.ssl. 
+ val fileServerSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.fs", Some(defaultSSLOptions)) + val akkaSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.akka", Some(defaultSSLOptions)) + + logDebug(s"SSLConfiguration for file server: $fileServerSSLOptions") + logDebug(s"SSLConfiguration for Akka: $akkaSSLOptions") + + val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) { + val trustStoreManagers = + for (trustStore <- fileServerSSLOptions.trustStore) yield { + val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream() + + try { + val ks = KeyStore.getInstance(KeyStore.getDefaultType) + ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray) + + val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) + tmf.init(ks) + tmf.getTrustManagers + } finally { + input.close() + } + } + + lazy val credulousTrustStoreManagers = Array({ + logWarning("Using 'accept-all' trust manager for SSL connections.") + new X509TrustManager { + override def getAcceptedIssuers: Array[X509Certificate] = null + + override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} + + override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} + }: TrustManager + }) + + val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.getOrElse("Default")) + sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null) + + val hostVerifier = new HostnameVerifier { + override def verify(s: String, sslSession: SSLSession): Boolean = true + } + + (Some(sslContext.getSocketFactory), Some(hostVerifier)) + } else { + (None, None) + } + /** * Split a comma separated String, filter out any empty items, and return a Set of strings */ diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a0ce107f43b16..0dbd26146cb13 100644 --- 
a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -17,9 +17,14 @@ package org.apache.spark +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean + import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, LinkedHashSet} +import scala.collection.mutable.LinkedHashSet + import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.util.Utils /** * Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. @@ -46,12 +51,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new HashMap[String, String]() + private val settings = new ConcurrentHashMap[String, String]() if (loadDefaults) { // Load any spark.* system properties - for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) { - settings(k) = v + for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { + set(key, value) } } @@ -63,7 +68,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } - settings(key) = value + settings.put(translateConfKey(key, warn = true), value) this } @@ -129,15 +134,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Set multiple parameters together */ def setAll(settings: Traversable[(String, String)]) = { - this.settings ++= settings + this.settings.putAll(settings.toMap.asJava) this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - if (!settings.contains(key)) { - settings(key) = value - } + settings.putIfAbsent(translateConfKey(key, warn = true), value) this } @@ -163,21 +166,23 @@ class 
SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Get a parameter; throws a NoSuchElementException if it's not set */ def get(key: String): String = { - settings.getOrElse(key, throw new NoSuchElementException(key)) + getOption(key).getOrElse(throw new NoSuchElementException(key)) } /** Get a parameter, falling back to a default if not set */ def get(key: String, defaultValue: String): String = { - settings.getOrElse(key, defaultValue) + getOption(key).getOrElse(defaultValue) } /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - settings.get(key) + Option(settings.get(translateConfKey(key))) } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.clone().toArray + def getAll: Array[(String, String)] = { + settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray + } /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { @@ -224,11 +229,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getAppId: String = get("spark.app.id") /** Does the configuration contain a given parameter? */ - def contains(key: String): Boolean = settings.contains(key) + def contains(key: String): Boolean = settings.containsKey(translateConfKey(key)) /** Copy this object */ override def clone: SparkConf = { - new SparkConf(false).setAll(settings) + new SparkConf(false).setAll(getAll) } /** @@ -240,7 +245,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. 
*/ private[spark] def validateSettings() { - if (settings.contains("spark.local.dir")) { + if (contains("spark.local.dir")) { val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." logWarning(msg) @@ -265,7 +270,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } // Validate spark.executor.extraJavaOptions - settings.get(executorOptsKey).map { javaOpts => + getOption(executorOptsKey).map { javaOpts => if (javaOpts.contains("-Dspark")) { val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." @@ -281,7 +286,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { // Validate memory fractions val memoryKeys = Seq( "spark.storage.memoryFraction", - "spark.shuffle.memoryFraction", + "spark.shuffle.memoryFraction", "spark.shuffle.safetyFraction", "spark.storage.unrollFraction", "spark.storage.safetyFraction") @@ -345,11 +350,22 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { * configuration out for debugging. */ def toDebugString: String = { - settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") + getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } + } -private[spark] object SparkConf { +private[spark] object SparkConf extends Logging { + + private val deprecatedConfigs: Map[String, DeprecatedConfig] = { + val configs = Seq( + DeprecatedConfig("spark.files.userClassPathFirst", "spark.executor.userClassPathFirst", + "1.3"), + DeprecatedConfig("spark.yarn.user.classpath.first", null, "1.3", + "Use spark.{driver,executor}.userClassPathFirst instead.")) + configs.map { x => (x.oldName, x) }.toMap + } + /** * Return whether the given config is an akka config (e.g. akka.actor.provider). 
* Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout). @@ -366,6 +382,7 @@ private[spark] object SparkConf { isAkkaConf(name) || name.startsWith("spark.akka") || name.startsWith("spark.auth") || + name.startsWith("spark.ssl") || isSparkPortConf(name) } @@ -375,4 +392,63 @@ private[spark] object SparkConf { def isSparkPortConf(name: String): Boolean = { (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.") } + + /** + * Translate the configuration key if it is deprecated and has a replacement, otherwise just + * returns the provided key. + * + * @param userKey Configuration key from the user / caller. + * @param warn Whether to print a warning if the key is deprecated. Warnings will be printed + * only once for each key. + */ + def translateConfKey(userKey: String, warn: Boolean = false): String = { + deprecatedConfigs.get(userKey) + .map { deprecatedKey => + if (warn) { + deprecatedKey.warn() + } + deprecatedKey.newName.getOrElse(userKey) + }.getOrElse(userKey) + } + + /** + * Holds information about keys that have been deprecated or renamed. + * + * @param oldName Old configuration key. + * @param newName New configuration key, or `null` if key has no replacement, in which case the + * deprecated key will be used (but the warning message will still be printed). + * @param version Version of Spark where key was deprecated. + * @param deprecationMessage Message to include in the deprecation warning; mandatory when + * `newName` is not provided. 
+ */ + private case class DeprecatedConfig( + oldName: String, + _newName: String, + version: String, + deprecationMessage: String = null) { + + private val warned = new AtomicBoolean(false) + val newName = Option(_newName) + + if (newName == null && (deprecationMessage == null || deprecationMessage.isEmpty())) { + throw new IllegalArgumentException("Need new config name or deprecation message.") + } + + def warn(): Unit = { + if (warned.compareAndSet(false, true)) { + if (newName != null) { + val message = Option(deprecationMessage).getOrElse( + s"Please use the alternative '$newName' instead.") + logWarning( + s"The configuration option '$oldName' has been replaced as of Spark $version and " + + s"may be removed in the future. $message") + } else { + logWarning( + s"The configuration option '$oldName' has been deprecated as of Spark $version and " + + s"may be removed in the future. $deprecationMessage") + } + } + } + + } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ff5d796ee2766..930d4bea4785b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -20,33 +20,42 @@ package org.apache.spark import scala.language.implicitConversions import java.io._ +import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger import java.util.UUID.randomUUID + import scala.collection.{Map, Set} import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} + +import akka.actor.Props + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} 
-import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, + FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, + TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} + import org.apache.mesos.MesosNativeLibrary -import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.executor.TriggerThreadDump -import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} +import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, + FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, + SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ @@ -85,6 +94,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() + @volatile private var stopped: Boolean = false + + private def assertNotStopped(): Unit = { + if (stopped) { + throw new IllegalStateException("Cannot 
call methods on a stopped SparkContext") + } + } + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -174,7 +191,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") - + private[spark] val conf = config.clone() conf.validateSettings() @@ -232,7 +249,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) // Create the Spark execution environment (cache, map output tracker, etc) - private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus) + + // This function allows components created by SparkEnv to be mocked in unit tests: + private[spark] def createSparkEnv( + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus): SparkEnv = { + SparkEnv.createDriverEnv(conf, isLocal, listenerBus) + } + + private[spark] val env = createSparkEnv(conf, isLocal, listenerBus) SparkEnv.set(env) // Used to store a URL for each static file/jar together with the file's local timestamp @@ -271,7 +297,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // the bound port to the cluster manager properly ui.foreach(_.bind()) - /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ + /** + * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. + * + * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * plan to set some global configurations for all Hadoop RDDs. 
+ */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) // Add each JAR given through the constructor @@ -313,11 +344,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli executorEnvs ++= conf.getExecutorEnv // Set SPARK_USER for user who is running SparkContext. - val sparkUser = Option { - Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name")) - }.getOrElse { - SparkContext.SPARK_UNKNOWN_USER - } + val sparkUser = Utils.getCurrentUserName() executorEnvs("SPARK_USER") = sparkUser // Create and start the scheduler @@ -379,9 +406,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } executorAllocationManager.foreach(_.start()) - // At this point, all relevant SparkListeners have been registered, so begin releasing events - listenerBus.start() - private[spark] val cleaner: Option[ContextCleaner] = { if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { Some(new ContextCleaner(this)) @@ -391,6 +415,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } cleaner.foreach(_.start()) + setupAndStartListenerBus() postEnvironmentUpdate() postApplicationStart() @@ -520,12 +545,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Distribute a local Scala collection to form an RDD. * - * @note Parallelize acts lazily. If `seq` is a mutable collection and is - * altered after the call to parallelize and before the first action on the - * RDD, the resultant RDD will reflect the modified collection. Pass a copy of - * the argument to avoid this. + * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call + * to parallelize and before the first action on the RDD, the resultant RDD will reflect the + * modified collection. Pass a copy of the argument to avoid this. + * @note avoid using `parallelize(Seq())` to create an empty `RDD`. 
Consider `emptyRDD` for an + * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions. */ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + assertNotStopped() new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } @@ -541,6 +568,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -550,6 +578,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Hadoop-supported file system URI, and return it as an RDD of Strings. */ def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = { + assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minPartitions).map(pair => pair._2.toString).setName(path) } @@ -583,6 +612,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -628,6 +658,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -645,6 +676,9 @@ class SparkContext(config: 
SparkConf) extends Logging with ExecutorAllocationCli * * Load data from a flat binary file, assuming the length of each record is constant. * + * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * has the provided record length. + * * @param path Directory to the input data files * @param recordLength The length at which to split the records * @return An RDD of data with values, represented as byte arrays @@ -652,13 +686,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) : RDD[Array[Byte]] = { + assertNotStopped() conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, classOf[FixedLengthBinaryInputFormat], classOf[LongWritable], classOf[BytesWritable], conf=conf) - val data = br.map{ case (k, v) => v.getBytes} + val data = br.map { case (k, v) => + val bytes = v.getBytes + assert(bytes.length == recordLength, "Byte array does not have correct length") + bytes + } data } @@ -667,16 +706,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), * using the older MapReduce API (`org.apache.hadoop.mapred`). * - * @param conf JobConf for setting up the dataset + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. * @param inputFormatClass Class of the InputFormat * @param keyClass Class of the keys * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. 
* * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def hadoopRDD[K, V]( conf: JobConf, @@ -685,18 +728,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions) } /** Get an RDD for a Hadoop file with an arbitrary InputFormat - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - * */ + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. 
+ */ def hadoopFile[K, V]( path: String, inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -704,6 +749,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) @@ -726,9 +772,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def hadoopFile[K, V, F <: InputFormat[K, V]] (path: String, minPartitions: Int) @@ -749,9 +796,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. 
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = @@ -773,9 +821,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * and extra configuration options to pass to the input format. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( path: String, @@ -783,6 +832,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kClass: Class[K], vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + assertNotStopped() + // The call to new NewHadoopJob automatically adds security credentials to conf, + // so we don't need to explicitly add them ourselves val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -793,31 +845,46 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * + * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast. 
+ * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param fClass Class of the InputFormat + * @param kClass Class of the keys + * @param vClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { - new NewHadoopRDD(this, fClass, kClass, vClass, conf) + assertNotStopped() + // Add necessary security credentials to the JobConf. Required to access secure HDFS. + val jconf = new JobConf(conf) + SparkHadoopUtil.get.addCredentials(jconf) + new NewHadoopRDD(this, fClass, kClass, vClass, jconf) } /** Get an RDD for a Hadoop SequenceFile with given key and value types. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. 
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V], minPartitions: Int ): RDD[(K, V)] = { + assertNotStopped() val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) } @@ -825,13 +892,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Get an RDD for a Hadoop SequenceFile with given key and value types. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. * */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V] - ): RDD[(K, V)] = + def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() sequenceFile(path, keyClass, valueClass, defaultMinPartitions) + } /** * Version of sequenceFile() for types implicitly convertible to Writables through a @@ -850,15 +919,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * allow it to figure out the Writable class to use in the subclass case. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. 
- * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def sequenceFile[K, V] (path: String, minPartitions: Int = defaultMinPartitions) (implicit km: ClassTag[K], vm: ClassTag[V], kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { + assertNotStopped() val kc = kcf() val vc = vcf() val format = classOf[SequenceFileInputFormat[Writable, Writable]] @@ -880,6 +951,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions ): RDD[T] = { + assertNotStopped() sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions) .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader)) } @@ -891,11 +963,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** Build the union of a list of RDDs. */ - def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) + def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = { + val partitioners = rdds.flatMap(_.partitioner).toSet + if (partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, rdds) + } else { + new UnionRDD(this, rdds) + } + } /** Build the union of a list of RDDs passed as variable-length arguments. */ def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = - new UnionRDD(this, Seq(first) ++ rest) + union(Seq(first) ++ rest) /** Get an RDD that has no partitions or elements. 
*/ def emptyRDD[T: ClassTag] = new EmptyRDD[T](this) @@ -907,7 +986,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = - new Accumulator(initialValue, param) + { + val acc = new Accumulator(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display @@ -915,7 +998,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { - new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** @@ -924,8 +1009,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param) + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = { + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the @@ -934,8 +1022,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param, Some(name)) + def accumulable[R, T](initialValue: 
R, name: String)(implicit param: AccumulableParam[R, T]) = { + val acc = new Accumulable(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an accumulator from a "mutable collection" type. @@ -946,7 +1037,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { val param = new GrowableAccumulableParam[R,T] - new Accumulable(initialValue, param) + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** @@ -955,6 +1048,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * The variable will be sent to each cluster only once. */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { + assertNotStopped() + if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have created RDD broadcast variables but not used them: + logWarning("Can not directly broadcast RDDs; instead, call collect() and " + + "broadcast the result (see SPARK-5063)") + } val bc = env.broadcastManager.newBroadcast[T](value, isLocal) val callSite = getCallSite logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) @@ -968,12 +1068,48 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. */ - def addFile(path: String) { + def addFile(path: String): Unit = { + addFile(path, false) + } + + /** + * Add a file to be downloaded with this Spark job on every node. 
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(fileName)` to find its download location. + * + * A directory can be given if the recursive option is set to true. Currently directories are only + * supported for Hadoop-supported filesystems. + */ + def addFile(path: String, recursive: Boolean): Unit = { val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case "local" => "file:" + uri.getPath - case _ => path + val schemeCorrectedPath = uri.getScheme match { + case null | "local" => "file:" + uri.getPath + case _ => path + } + + val hadoopPath = new Path(schemeCorrectedPath) + val scheme = new URI(schemeCorrectedPath).getScheme + if (!Array("http", "https", "ftp").contains(scheme)) { + val fs = hadoopPath.getFileSystem(hadoopConfiguration) + if (!fs.exists(hadoopPath)) { + throw new FileNotFoundException(s"Added file $hadoopPath does not exist.") + } + val isDir = fs.isDirectory(hadoopPath) + if (!isLocal && scheme == "file" && isDir) { + throw new SparkException(s"addFile does not support local directories when not running " + + "local mode.") + } + if (!recursive && isDir) { + throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " + + "turned on.") + } + } + + val key = if (!isLocal && scheme == "file") { + env.httpFileServer.addFile(new File(uri.getPath)) + } else { + schemeCorrectedPath } val timestamp = System.currentTimeMillis addedFiles(key) = timestamp @@ -995,10 +1131,27 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli listenerBus.addListener(listener) } + /** + * Express a preference to the cluster manager for a given total number of executors. + * This can result in canceling pending requests or filing additional requests. 
+ * This is currently only supported in YARN mode. Return whether the request is received. + */ + private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { + assert(master.contains("yarn") || dynamicAllocationTesting, + "Requesting executors is currently only supported in YARN mode") + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.requestTotalExecutors(numExecutors) + case _ => + logWarning("Requesting executors is only supported in coarse-grained mode") + false + } + } + /** * :: DeveloperApi :: * Request an additional number of executors from the cluster manager. - * This is currently only supported in Yarn mode. Return whether the request is received. + * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { @@ -1016,7 +1169,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. - * This is currently only supported in Yarn mode. Return whether the request is received. + * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { @@ -1047,6 +1200,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * memory available for caching. 
*/ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + assertNotStopped() env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => (blockManagerId.host + ":" + blockManagerId.port, mem) } @@ -1059,6 +1213,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + assertNotStopped() val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) @@ -1076,6 +1231,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getExecutorStorageStatus: Array[StorageStatus] = { + assertNotStopped() env.blockManager.master.getStorageStatus } @@ -1085,6 +1241,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getAllPools: Seq[Schedulable] = { + assertNotStopped() // TODO(xiajunluan): We should take nested pools into account taskScheduler.rootPool.schedulableQueue.toSeq } @@ -1095,6 +1252,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getPoolForName(pool: String): Option[Schedulable] = { + assertNotStopped() Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) } @@ -1102,6 +1260,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return current scheduling mode */ def getSchedulingMode: SchedulingMode.SchedulingMode = { + assertNotStopped() taskScheduler.schedulingMode } @@ -1176,7 +1335,19 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli null } } else { - env.httpFileServer.addJar(new File(uri.getPath)) + try { + env.httpFileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null + case e: Exception => + // For now just log an error but allow to go 
through so spark examples work. + // The spark examples don't really need the jar distributed since its also + // the app jar. + logError("Error adding jar (" + e + "), was the --addJars option used?") + null + } } // A JAR file which exists locally on every worker node case "local" => @@ -1207,16 +1378,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { postApplicationEnd() ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { + if (!stopped) { + stopped = true env.metricsSystem.report() metadataCleaner.cancel() env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() + dagScheduler.stop() + dagScheduler = null + progressBar.foreach(_.stop()) taskScheduler = null // TODO: Cache.stop()? env.stop() @@ -1290,12 +1460,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (dagScheduler == null) { - throw new SparkException("SparkContext has been shutdown") + if (stopped) { + throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite val cleanedFunc = clean(func) logInfo("Starting job: " + callSite.shortForm) + if (conf.getBoolean("spark.logLineage", false)) { + logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) + } dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) progressBar.foreach(_.finishAll()) @@ -1378,6 +1551,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { + assertNotStopped() val callSite = getCallSite 
logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime @@ -1400,6 +1574,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit, resultFunc: => R): SimpleFutureAction[R] = { + assertNotStopped() val cleanF = clean(processPartition) val callSite = getCallSite val waiter = dagScheduler.submitJob( @@ -1418,11 +1593,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * for more information. */ def cancelJobGroup(groupId: String) { + assertNotStopped() dagScheduler.cancelJobGroup(groupId) } /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs() { + assertNotStopped() dagScheduler.cancelAllJobs() } @@ -1469,13 +1646,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getCheckpointDir = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ - def defaultParallelism: Int = taskScheduler.defaultParallelism + def defaultParallelism: Int = { + assertNotStopped() + taskScheduler.defaultParallelism + } /** Default min number of partitions for Hadoop RDDs when not given by user */ @deprecated("use defaultMinPartitions", "1.0.0") def defaultMinSplits: Int = math.min(defaultParallelism, 2) - /** Default min number of partitions for Hadoop RDDs when not given by user */ + /** + * Default min number of partitions for Hadoop RDDs when not given by user + * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. 
+ * The reasons for this are discussed in https://github.com/mesos/spark/pull/718 + */ def defaultMinPartitions: Int = math.min(defaultParallelism, 2) private val nextShuffleId = new AtomicInteger(0) @@ -1487,6 +1671,58 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + /** + * Registers listeners specified in spark.extraListeners, then starts the listener bus. + * This should be called after all internal listeners have been registered with the listener bus + * (e.g. after the web UI and event logging listeners have been registered). + */ + private def setupAndStartListenerBus(): Unit = { + // Use reflection to instantiate listeners specified via `spark.extraListeners` + try { + val listenerClassNames: Seq[String] = + conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "") + for (className <- listenerClassNames) { + // Use reflection to find the right constructor + val constructors = { + val listenerClass = Class.forName(className) + listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] + } + val constructorTakingSparkConf = constructors.find { c => + c.getParameterTypes.sameElements(Array(classOf[SparkConf])) + } + lazy val zeroArgumentConstructor = constructors.find { c => + c.getParameterTypes.isEmpty + } + val listener: SparkListener = { + if (constructorTakingSparkConf.isDefined) { + constructorTakingSparkConf.get.newInstance(conf) + } else if (zeroArgumentConstructor.isDefined) { + zeroArgumentConstructor.get.newInstance() + } else { + throw new SparkException( + s"$className did not have a zero-argument constructor or a" + + " single-argument constructor that accepts SparkConf. 
Note: if the class is" + + " defined inside of another Scala class, then its constructors may accept an" + + " implicit parameter that references the enclosing class; in this case, you must" + + " define the listener as a top-level class in order to prevent this extra" + + " parameter from breaking Spark's ability to find a valid constructor.") + } + } + listenerBus.addListener(listener) + logInfo(s"Registered listener $className") + } + } catch { + case e: Exception => + try { + stop() + } finally { + throw new SparkException(s"Exception when registering SparkListener", e) + } + } + + listenerBus.start() + } + /** Post the application start event */ private def postApplicationStart() { // Note: this code assumes that the task scheduler has been initialized and has contacted @@ -1506,8 +1742,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val schedulingMode = getSchedulingMode.toString val addedJarPaths = addedJars.keys.toSeq val addedFilePaths = addedFiles.keys.toSeq - val environmentDetails = - SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths) + val environmentDetails = SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, + addedFilePaths) val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails) listenerBus.post(environmentUpdate) } @@ -1637,8 +1873,6 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" - private[spark] val SPARK_UNKNOWN_USER = "" - private[spark] val DRIVER_IDENTIFIER = "" // The following deprecated objects have already been copied to `object AccumulatorParam` to @@ -1692,8 +1926,14 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. 
This is " + "kept here only for backward compatibility.", "1.3.0") def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]) = + rdd: RDD[(K, V)]) = { + val kf = implicitly[K => Writable] + val vf = implicitly[V => Writable] + // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it + implicit val keyWritableFactory = new WritableFactory[K](_ => null, kf) + implicit val valueWritableFactory = new WritableFactory[V](_ => null, vf) RDD.rddToSequenceFileRDDFunctions(rdd) + } @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") @@ -1710,20 +1950,35 @@ object SparkContext extends Logging { def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = RDD.numericRDDToDoubleRDDFunctions(rdd) - // Implicit conversions to common Writable types, for saveAsSequenceFile + // The following deprecated functions have already been moved to `object WritableFactory` to + // make the compiler find them automatically. They are still kept here for backward compatibility. + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def intToIntWritable(i: Int): IntWritable = new IntWritable(i) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def longToLongWritable(l: Long): LongWritable = new LongWritable(l) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def floatToFloatWritable(f: Float): FloatWritable = new FloatWritable(f) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. 
This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def doubleToDoubleWritable(d: Double): DoubleWritable = new DoubleWritable(d) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def boolToBoolWritable (b: Boolean): BooleanWritable = new BooleanWritable(b) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def bytesToBytesWritable (aob: Array[Byte]): BytesWritable = new BytesWritable(aob) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def stringToText(s: String): Text = new Text(s) private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) @@ -1902,7 +2157,7 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf) val masterUrls = localCluster.start() val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) @@ -1943,7 +2198,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] @@ -2013,7 +2268,7 @@ object WritableConverter { new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) } - // The following implicit functions were in SparkContext before 1.2 and users had to + // The following implicit functions were in 
SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward // compatibility and forward to the following functions directly. @@ -2046,3 +2301,46 @@ object WritableConverter { implicit def writableWritableConverter[T <: Writable](): WritableConverter[T] = new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) } + +/** + * A class encapsulating how to convert some type T to Writable. It stores both the Writable class + * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. + * The Writable class will be used in `SequenceFileRDDFunctions`. + */ +private[spark] class WritableFactory[T]( + val writableClass: ClassTag[T] => Class[_ <: Writable], + val convert: T => Writable) extends Serializable + +object WritableFactory { + + private[spark] def simpleWritableFactory[T: ClassTag, W <: Writable : ClassTag](convert: T => W) + : WritableFactory[T] = { + val writableClass = implicitly[ClassTag[W]].runtimeClass.asInstanceOf[Class[W]] + new WritableFactory[T](_ => writableClass, convert) + } + + implicit def intWritableFactory: WritableFactory[Int] = + simpleWritableFactory(new IntWritable(_)) + + implicit def longWritableFactory: WritableFactory[Long] = + simpleWritableFactory(new LongWritable(_)) + + implicit def floatWritableFactory: WritableFactory[Float] = + simpleWritableFactory(new FloatWritable(_)) + + implicit def doubleWritableFactory: WritableFactory[Double] = + simpleWritableFactory(new DoubleWritable(_)) + + implicit def booleanWritableFactory: WritableFactory[Boolean] = + simpleWritableFactory(new BooleanWritable(_)) + + implicit def bytesWritableFactory: WritableFactory[Array[Byte]] = + simpleWritableFactory(new BytesWritable(_)) + + implicit def stringWritableFactory: WritableFactory[String] = + simpleWritableFactory(new Text(_)) + + 
implicit def writableWritableFactory[T <: Writable: ClassTag]: WritableFactory[T] = + simpleWritableFactory(w => w) + +} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4d418037bd33f..2a0c7e756dd3a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,7 +34,8 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ @@ -67,6 +68,7 @@ class SparkEnv ( val sparkFilesDir: String, val metricsSystem: MetricsSystem, val shuffleMemoryManager: ShuffleMemoryManager, + val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { private[spark] var isStopped = false @@ -76,6 +78,8 @@ class SparkEnv ( // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). 
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() + private var driverTmpDirToDelete: Option[String] = None + private[spark] def stop() { isStopped = true pythonWorkers.foreach { case(key, worker) => worker.stop() } @@ -86,6 +90,7 @@ class SparkEnv ( blockManager.stop() blockManager.master.stop() metricsSystem.stop() + outputCommitCoordinator.stop() actorSystem.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release @@ -93,6 +98,22 @@ class SparkEnv ( // actorSystem.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. + // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the + // current working dir in executor which we do not need to delete. 
+ driverTmpDirToDelete match { + case Some(path) => { + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } + } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor + } } private[spark] @@ -151,7 +172,8 @@ object SparkEnv extends Logging { private[spark] def createDriverEnv( conf: SparkConf, isLocal: Boolean, - listenerBus: LiveListenerBus): SparkEnv = { + listenerBus: LiveListenerBus, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!") assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") val hostname = conf.get("spark.driver.host") @@ -163,7 +185,8 @@ object SparkEnv extends Logging { port, isDriver = true, isLocal = isLocal, - listenerBus = listenerBus + listenerBus = listenerBus, + mockOutputCommitCoordinator = mockOutputCommitCoordinator ) } @@ -202,7 +225,8 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean, listenerBus: LiveListenerBus = null, - numUsableCores: Int = 0): SparkEnv = { + numUsableCores: Int = 0, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { // Listener bus is only used on the driver if (isDriver) { @@ -326,6 +350,10 @@ object SparkEnv extends Logging { // Then we can start the metrics system. MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { + // We need to set the executor ID before the MetricsSystem is created because sources and + // sinks specified in the metrics configuration file will want to incorporate this executor's + // ID into the metrics they report. 
+ conf.set("spark.executor.id", executorId) val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager) ms.start() ms @@ -335,7 +363,7 @@ object SparkEnv extends Logging { // this is a temporary directory; in distributed mode, this is the executor's current working // directory. val sparkFilesDir: String = if (isDriver) { - Utils.createTempDir().getAbsolutePath + Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath } else { "." } @@ -346,7 +374,14 @@ object SparkEnv extends Logging { "levels using the RDD.persist() method instead.") } - new SparkEnv( + val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { + new OutputCommitCoordinator(conf) + } + val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator", + new OutputCommitCoordinatorActor(outputCommitCoordinator)) + outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor) + + val envInstance = new SparkEnv( executorId, actorSystem, serializer, @@ -362,7 +397,17 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, shuffleMemoryManager, + outputCommitCoordinator, conf) + + // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is + // called, and we only need to do it for driver. Because driver may run as a service, and if we + // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. 
+ if (isDriver) { + envInstance.driverTmpDirToDelete = Some(sparkFilesDir) + } + + envInstance } /** diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 40237596570de..6eb4537d10477 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path +import org.apache.spark.executor.CommitDeniedException import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD @@ -105,24 +106,56 @@ class SparkHadoopWriter(@transient jobConf: JobConf) def commit() { val taCtxt = getTaskContext() val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { + + // Called after we have decided to commit + def performCommit(): Unit = { try { cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") + logInfo (s"$taID: Committed") } catch { - case e: IOException => { + case e: IOException => logError("Error committing the output of task: " + taID.value, e) cmtr.abortTask(taCtxt) throw e + } + } + + // First, check whether the task's output has already been committed by some other attempt + if (cmtr.needsTaskCommit(taCtxt)) { + // The task output needs to be committed, but we don't know whether some other task attempt + // might be racing to commit the same output partition. Therefore, coordinate with the driver + // in order to determine whether this attempt can commit (see SPARK-4879). 
+ val shouldCoordinateWithDriver: Boolean = { + val sparkConf = SparkEnv.get.conf + // We only need to coordinate with the driver if there are multiple concurrent task + // attempts, which should only occur if speculation is enabled + val speculationEnabled = sparkConf.getBoolean("spark.speculation", false) + // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs + sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) + } + if (shouldCoordinateWithDriver) { + val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator + val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID) + if (canCommit) { + performCommit() + } else { + val msg = s"$taID: Not committed because the driver did not authorize commit" + logInfo(msg) + // We need to abort the task so that the driver can reschedule new attempts, if necessary + cmtr.abortTask(taCtxt) + throw new CommitDeniedException(msg, jobID, splitID, attemptID) } + } else { + // Speculation is disabled or a user has chosen to manually bypass the commit coordination + performCommit() } } else { - logInfo ("No need to commit output of task: " + taID.value) + // Some other attempt committed the output, so we do nothing and signal success + logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}") } } def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? val cmtr = getOutputCommitter() cmtr.commitJob(getJobContext()) } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala new file mode 100644 index 0000000000000..7d7fe1a446313 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.Serializable + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.TaskCompletionListener + + +object TaskContext { + /** + * Return the currently active TaskContext. This can be called inside of + * user functions to access contextual information about running tasks. + */ + def get(): TaskContext = taskContext.get + + private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + + // Note: protected[spark] instead of private[spark] to prevent the following two from + // showing up in JavaDoc. + /** + * Set the thread local TaskContext. Internal to Spark. + */ + protected[spark] def setTaskContext(tc: TaskContext): Unit = taskContext.set(tc) + + /** + * Unset the thread local TaskContext. Internal to Spark. + */ + protected[spark] def unset(): Unit = taskContext.remove() +} + + +/** + * Contextual information about a task which can be read or mutated during + * execution. To access the TaskContext for a running task, use: + * {{{ + * org.apache.spark.TaskContext.get() + * }}} + */ +abstract class TaskContext extends Serializable { + // Note: TaskContext must NOT define a get method. 
Otherwise it will prevent the Scala compiler + // from generating a static get method (based on the companion object's get method). + + // Note: Update JavaTaskContextCompileCheck when new methods are added to this class. + + // Note: getters in this class are defined with parentheses to maintain backward compatibility. + + /** + * Returns true if the task has completed. + */ + def isCompleted(): Boolean + + /** + * Returns true if the task has been killed. + */ + def isInterrupted(): Boolean + + @deprecated("use isRunningLocally", "1.2.0") + def runningLocally(): Boolean + + /** + * Returns true if the task is running locally in the driver program. + * @return + */ + def isRunningLocally(): Boolean + + /** + * Adds a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext + + /** + * Adds a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situations - success, failure, or cancellation. + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext + + /** + * Adds a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. + * + * @param f Callback function. + */ + @deprecated("use addTaskCompletionListener", "1.2.0") + def addOnCompleteCallback(f: () => Unit) + + /** + * The ID of the stage that this task belong to. + */ + def stageId(): Int + + /** + * The ID of the RDD partition that is computed by this task. + */ + def partitionId(): Int + + /** + * How many times this task has been attempted. 
The first task attempt will be assigned + * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers. + */ + def attemptNumber(): Int + + @deprecated("use attemptNumber", "1.3.0") + def attemptId(): Long + + /** + * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts + * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. + */ + def taskAttemptId(): Long + + /** ::DeveloperApi:: */ + @DeveloperApi + def taskMetrics(): TaskMetrics +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 9bb0c61e441f8..337c8e4ebebcd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -33,7 +33,7 @@ private[spark] class TaskContextImpl( with Logging { // For backwards-compatibility; this method is now deprecated as of 1.3.0. - override def attemptId: Long = taskAttemptId + override def attemptId(): Long = taskAttemptId // List of callback functions to execute when the task completes. 
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] @@ -87,10 +87,10 @@ private[spark] class TaskContextImpl( interrupted = true } - override def isCompleted: Boolean = completed + override def isCompleted(): Boolean = completed - override def isRunningLocally: Boolean = runningLocally + override def isRunningLocally(): Boolean = runningLocally - override def isInterrupted: Boolean = interrupted + override def isInterrupted(): Boolean = interrupted } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index af5fd8e0ac00c..29a5cd5fdac76 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -146,6 +146,20 @@ case object TaskKilled extends TaskFailedReason { override def toErrorString: String = "TaskKilled (killed intentionally)" } +/** + * :: DeveloperApi :: + * Task requested the driver to commit, but was denied. + */ +@DeveloperApi +case class TaskCommitDenied( + jobID: Int, + partitionID: Int, + attemptID: Int) + extends TaskFailedReason { + override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + + s" for job: $jobID, partition: $partitionID, attempt: $attemptID" +} + /** * :: DeveloperApi :: * The task failed because the executor that it was running on was lost. 
This may happen because diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 34078142f5385..35b324ba6f573 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -17,12 +17,13 @@ package org.apache.spark -import java.io.{File, FileInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConversions._ +import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -43,13 +44,38 @@ private[spark] object TestUtils { * Note: if this is used during class loader tests, class names should be unique * in order to avoid interference between tests. */ - def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { + def createJarWithClasses( + classNames: Seq[String], + toStringValue: String = "", + classNamesWithBase: Seq[(String, String)] = Seq(), + classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() - val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) + val files1 = for (name <- classNames) yield { + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + } + val files2 = for ((childName, baseName) <- classNamesWithBase) yield { + createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) + } val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) - createJar(files, jarFile) + createJar(files1 ++ files2, jarFile) } + /** + * Create a jar file containing multiple files. The `files` map contains a mapping of + * file names in the jar file to their contents. 
+ */ + def createJarWithFiles(files: Map[String, String], dir: File = null): URL = { + val tempDir = Option(dir).getOrElse(Utils.createTempDir()) + val jarFile = File.createTempFile("testJar", ".jar", tempDir) + val jarStream = new JarOutputStream(new FileOutputStream(jarFile)) + files.foreach { case (k, v) => + val entry = new JarEntry(k) + jarStream.putNextEntry(entry) + ByteStreams.copy(new ByteArrayInputStream(v.getBytes(UTF_8)), jarStream) + } + jarStream.close() + jarFile.toURI.toURL + } /** * Create a jar file that contains this set of files. All files will be located at the root @@ -85,15 +111,26 @@ private[spark] object TestUtils { } /** Creates a compiled class with the given name. Class file will be placed in destDir. */ - def createCompiledClass(className: String, destDir: File, value: String = ""): File = { + def createCompiledClass( + className: String, + destDir: File, + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { val compiler = ToolProvider.getSystemJavaCompiler + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") val sourceFile = new JavaSourceFromString(className, - "public class " + className + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + value + "\"; }}") + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. 
- compiler.getTask(null, null, null, null, null, Seq(sourceFile)).call() + val options = if (classpathUrls.nonEmpty) { + Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) + } else { + Seq() + } + compiler.getTask(null, null, null, options, null, Seq(sourceFile)).call() val fileName = className + ".class" val result = new File(fileName) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index bd451634e53d2..0f91c942ecd50 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -38,6 +38,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +/** + * Defines operations common to several Java RDD implementations. + * Note that this trait is not intended to be implemented by user code. + */ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This @@ -344,6 +348,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f) + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree + * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]] + */ + def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth) + + /** + * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2. + */ + def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2) + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". 
The function op(t1, t2) is allowed to @@ -365,6 +382,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { combOp: JFunction2[U, U, U]): U = rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U]) + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree + * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]] + */ + def treeAggregate[U]( + zeroValue: U, + seqOp: JFunction2[U, T, U], + combOp: JFunction2[U, U, U], + depth: Int): U = { + rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U]) + } + + /** + * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2. + */ + def treeAggregate[U]( + zeroValue: U, + seqOp: JFunction2[U, T, U], + combOp: JFunction2[U, U, U]): U = { + treeAggregate(zeroValue, seqOp, combOp, 2) + } + /** * Return the number of elements in the RDD. */ @@ -435,6 +476,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def first(): T = rdd.first() + /** + * @return true if and only if the RDD contains no elements at all. Note that an RDD + * may be empty even when it has at least 1 partition. + */ + def isEmpty(): Boolean = rdd.isEmpty() + /** * Save this RDD as a text file, using string representations of elements. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 97f5c9f257e09..6d6ed693be752 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -373,6 +373,15 @@ class JavaSparkContext(val sc: SparkContext) * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * etc). * + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. 
+ * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param inputFormatClass Class of the InputFormat + * @param keyClass Class of the keys + * @param valueClass Class of the values + * @param minPartitions Minimum number of Hadoop Splits to generate. + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -395,6 +404,14 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param inputFormatClass Class of the InputFormat + * @param keyClass Class of the keys + * @param valueClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -476,6 +493,14 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * + * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast. 
+ * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param fClass Class of the InputFormat + * @param kClass Class of the keys + * @param vClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -675,6 +700,9 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + * + * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { sc.hadoopConfiguration diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala new file mode 100644 index 0000000000000..164e95081583f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io.DataOutputStream +import java.net.Socket + +import py4j.GatewayServer + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port + * back to its caller via a callback port specified by the caller. + * + * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). + */ +private[spark] object PythonGatewayServer extends Logging { + def main(args: Array[String]): Unit = Utils.tryOrExit { + // Start a GatewayServer on an ephemeral port + val gatewayServer: GatewayServer = new GatewayServer(null, 0) + gatewayServer.start() + val boundPort: Int = gatewayServer.getListeningPort + if (boundPort == -1) { + logError("GatewayServer failed to bind; exiting") + System.exit(1) + } else { + logDebug(s"Started PythonGatewayServer on port $boundPort") + } + + // Communicate the bound port back to the caller via the caller-specified callback port + val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") + val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt + logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") + val callbackSocket = new Socket(callbackHost, callbackPort) + val dos = new DataOutputStream(callbackSocket.getOutputStream) + dos.writeInt(boundPort) + dos.close() + callbackSocket.close() + + // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: + while 
(System.in.read() != -1) { + // Do nothing + } + logDebug("Exiting due to broken pipe from Python driver") + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 5ba66178e2b78..c9181a29d4756 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -138,6 +138,11 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable + case array: Array[Any] => { + val arrayWriteable = new ArrayWritable(classOf[Writable]) + arrayWriteable.set(array.map(convertToWritable(_))) + arrayWriteable + } case other => throw new SparkException( s"Data of type ${other.getClass.getName} cannot be used") } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index bad40e6529f74..dcb6e6313a1d2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -67,17 +67,16 @@ private[spark] class PythonRDD( envVars += ("SPARK_REUSE_WORKER" -> "1") } val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + // Whether is the worker released into idle pool + @volatile var released = false // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) - var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() writerThread.join() - if (reuse_worker && complete_cleanly) { - env.releasePythonWorker(pythonExec, envVars.toMap, worker) - } else { + if (!reuse_worker || !released) { try { worker.close() } catch { @@ -125,8 +124,8 @@ private[spark] class 
PythonRDD( init, finish)) val memoryBytesSpilled = stream.readLong() val diskBytesSpilled = stream.readLong() - context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled - context.taskMetrics.diskBytesSpilled += diskBytesSpilled + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) read() case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python @@ -145,8 +144,12 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } + // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - complete_cleanly = true + if (reuse_worker) { + env.releasePythonWorker(pythonExec, envVars.toMap, worker) + released = true + } } null } @@ -245,13 +248,13 @@ private[spark] class PythonRDD( } catch { case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) - worker.shutdownOutput() + Utils.tryLog(worker.shutdownOutput()) case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). 
_exception = e - worker.shutdownOutput() + Utils.tryLog(worker.shutdownOutput()) } finally { // Release memory used by this thread for shuffles env.shuffleMemoryManager.releaseMemoryForThisThread() @@ -300,6 +303,7 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { override def getPartitions = prev.partitions + override val partitioner = prev.partitioner override def compute(split: Partition, context: TaskContext) = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (Utils.deserializeLongValue(a), b) @@ -313,6 +317,7 @@ private object SpecialLengths { val PYTHON_EXCEPTION_THROWN = -2 val TIMING_DATA = -3 val END_OF_STREAM = -4 + val NULL = -5 } private[spark] object PythonRDD extends Logging { @@ -325,6 +330,15 @@ private[spark] object PythonRDD extends Logging { } } + /** + * Return an RDD of values from an RDD of (Long, Array[Byte]), with preservePartitions=true + * + * This is useful for PySpark to have the partitioner after partitionBy() + */ + def valueOfPair(pair: JavaPairRDD[Long, Array[Byte]]): JavaRDD[Array[Byte]] = { + pair.rdd.mapPartitions(it => it.map(_._2), true) + } + /** * Adapter for calling SparkContext#runJob from Python. * @@ -371,54 +385,25 @@ private[spark] object PythonRDD extends Logging { } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { - // The right way to implement this would be to use TypeTags to get the full - // type of T. 
Since I don't want to introduce breaking changes throughout the - // entire Spark API, I have to use this hacky approach: - if (iter.hasNext) { - val first = iter.next() - val newIter = Seq(first).iterator ++ iter - first match { - case arr: Array[Byte] => - newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes => - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } - case string: String => - newIter.asInstanceOf[Iterator[String]].foreach { str => - writeUTF(str, dataOut) - } - case stream: PortableDataStream => - newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => - val bytes = stream.toArray() - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } - case (key: String, stream: PortableDataStream) => - newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { - case (key, stream) => - writeUTF(key, dataOut) - val bytes = stream.toArray() - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } - case (key: String, value: String) => - newIter.asInstanceOf[Iterator[(String, String)]].foreach { - case (key, value) => - writeUTF(key, dataOut) - writeUTF(value, dataOut) - } - case (key: Array[Byte], value: Array[Byte]) => - newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { - case (key, value) => - dataOut.writeInt(key.length) - dataOut.write(key) - dataOut.writeInt(value.length) - dataOut.write(value) - } - case other => - throw new SparkException("Unexpected element type " + first.getClass) - } + + def write(obj: Any): Unit = obj match { + case null => + dataOut.writeInt(SpecialLengths.NULL) + case arr: Array[Byte] => + dataOut.writeInt(arr.length) + dataOut.write(arr) + case str: String => + writeUTF(str, dataOut) + case stream: PortableDataStream => + write(stream.toArray()) + case (key, value) => + write(key) + write(value) + case other => + throw new SparkException("Unexpected element type " + other.getClass) } + + iter.foreach(write) } /** diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index be5ebfa9219d3..acbaba6791850 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,11 +17,14 @@ package org.apache.spark.api.python -import java.io.{File, InputStream, IOException, OutputStream} +import java.io.{File} +import java.util.{List => JList} +import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext +import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} private[spark] object PythonUtils { /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */ @@ -39,4 +42,15 @@ private[spark] object PythonUtils { def mergePythonPaths(paths: String*): String = { paths.filter(_ != "").mkString(File.pathSeparator) } + + def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = { + sc.parallelize(List("a", null, "b")) + } + + /** + * Convert list of T into seq of T (for calling API with varargs) + */ + def toSeq[T](cols: JList[T]): Seq[T] = { + cols.toList.toSeq + } } diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index a4153aaa926f8..fb52a960e0765 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -153,7 +153,10 @@ private[spark] object SerDeUtil extends Logging { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case array: Array[Any] => array.toSeq + case _ => obj.asInstanceOf[JArrayList[_]].asScala + } } else { Seq(obj) } @@ -199,7 +202,10 @@ private[spark] object SerDeUtil extends Logging { * representation is serialized */ def pairRDDToPython(rdd: 
RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { - val (keyFailed, valueFailed) = checkPickle(rdd.first()) + val (keyFailed, valueFailed) = rdd.take(1) match { + case Array() => (false, false) + case Array(first) => checkPickle(first) + } rdd.mapPartitions { iter => val cleaned = iter.map { case (k, v) => @@ -226,10 +232,12 @@ private[spark] object SerDeUtil extends Logging { } val rdd = pythonToJava(pyRDD, batched).rdd - rdd.first match { - case obj if isPair(obj) => + rdd.take(1) match { + case Array(obj) if isPair(obj) => // we only accept (K, V) - case other => throw new SparkException( + case Array() => + // we also accept empty collections + case Array(other) => throw new SparkException( s"RDD element of type ${other.getClass.getName} cannot be used") } rdd.map { obj => diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index c0cbd28a845be..cf289fb3ae39f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -107,7 +107,6 @@ private[python] class WritableToDoubleArrayConverter extends Converter[Any, Arra * given directory (probably a temp directory) */ object WriteInputFormatTestDataGenerator { - import SparkContext._ def main(args: Array[String]) { val path = args(0) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 31d6958c403b3..1444c0dd3d2d6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -151,7 +151,7 @@ private[broadcast] object HttpBroadcast extends Logging { } private def createServer(conf: SparkConf) { - broadcastDir = 
Utils.createTempDir(Utils.getLocalDir(conf)) + broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast") val broadcastPort = conf.getInt("spark.broadcast.port", 0) server = new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") @@ -199,6 +199,7 @@ private[broadcast] object HttpBroadcast extends Logging { uc = new URL(url).openConnection() uc.setConnectTimeout(httpReadTimeout) } + Utils.setupSecureURLConnection(uc, securityManager) val in = { uc.setReadTimeout(httpReadTimeout) diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 65a1a8fd7e929..ae55b4ff40b74 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -28,5 +28,14 @@ private[spark] class ApplicationDescription( val user = System.getProperty("user.name", "") + def copy( + name: String = name, + maxCores: Option[Int] = maxCores, + memoryPerSlave: Int = memoryPerSlave, + command: Command = command, + appUiUrl: String = appUiUrl, + eventLogDir: Option[String] = eventLogDir): ApplicationDescription = + new ApplicationDescription(name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir) + override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 7c1c831c248fc..237d26fc6bd0e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -39,7 +39,8 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) val timeout = AkkaUtils.askTimeout(conf) override def preStart() = { - masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master)) + masterActor = context.actorSelection( + 
Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system))) context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -67,8 +68,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) .map(Utils.splitCommandString).getOrElse(Seq.empty) val sparkJavaOpts = Utils.sparkJavaOpts(conf) val javaOpts = sparkJavaOpts ++ extraJavaOpts - val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++ - driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts) + val command = new Command(mainClass, + Seq("{{WORKER_URL}}", "{{USER_JAR}}", driverArgs.mainClass) ++ driverArgs.driverOptions, + sys.env, classPathEntries, libraryPathEntries, javaOpts) val driverDescription = new DriverDescription( driverArgs.jarUrl, @@ -161,7 +163,7 @@ object Client { "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - Master.toAkkaUrl(driverArgs.master) + Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(actorSystem)) actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) actorSystem.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 2e1e52906ceeb..415bd50591692 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -23,14 +23,13 @@ import scala.collection.mutable.ListBuffer import org.apache.log4j.Level -import org.apache.spark.util.MemoryParam +import org.apache.spark.util.{IntParam, MemoryParam} /** * Command-line parser for the driver client. 
*/ private[spark] class ClientArguments(args: Array[String]) { - val defaultCores = 1 - val defaultMemory = 512 + import ClientArguments._ var cmd: String = "" // 'launch' or 'kill' var logLevel = Level.WARN @@ -39,9 +38,9 @@ private[spark] class ClientArguments(args: Array[String]) { var master: String = "" var jarUrl: String = "" var mainClass: String = "" - var supervise: Boolean = false - var memory: Int = defaultMemory - var cores: Int = defaultCores + var supervise: Boolean = DEFAULT_SUPERVISE + var memory: Int = DEFAULT_MEMORY + var cores: Int = DEFAULT_CORES private var _driverOptions = ListBuffer[String]() def driverOptions = _driverOptions.toSeq @@ -50,9 +49,9 @@ private[spark] class ClientArguments(args: Array[String]) { parse(args.toList) - def parse(args: List[String]): Unit = args match { - case ("--cores" | "-c") :: value :: tail => - cores = value.toInt + private def parse(args: List[String]): Unit = args match { + case ("--cores" | "-c") :: IntParam(value) :: tail => + cores = value parse(tail) case ("--memory" | "-m") :: MemoryParam(value) :: tail => @@ -106,9 +105,10 @@ private[spark] class ClientArguments(args: Array[String]) { |Usage: DriverClient kill | |Options: - | -c CORES, --cores CORES Number of cores to request (default: $defaultCores) - | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory) + | -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES) + | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY) | -s, --supervise Whether to restart the driver on failure + | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin System.err.println(usage) @@ -117,6 +117,10 @@ private[spark] class ClientArguments(args: Array[String]) { } object ClientArguments { + private[spark] val DEFAULT_CORES = 1 + private[spark] val DEFAULT_MEMORY = 512 // MB + private[spark] val DEFAULT_SUPERVISE = false + def isValidJarUrl(s: String): 
Boolean = { try { val uri = new URI(s) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 243d8edb72ed3..7f600d89604a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -148,15 +148,22 @@ private[deploy] object DeployMessages { // Master to MasterWebUI - case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], - activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo], - activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo], - status: MasterState) { + case class MasterStateResponse( + host: String, + port: Int, + restPort: Option[Int], + workers: Array[WorkerInfo], + activeApps: Array[ApplicationInfo], + completedApps: Array[ApplicationInfo], + activeDrivers: Array[DriverInfo], + completedDrivers: Array[DriverInfo], + status: MasterState) { Utils.checkHost(host, "Required hostname") assert (port > 0) def uri = "spark://" + host + ":" + port + def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p } } // WorkerWebUI to Worker diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala index 58c95dc4f9116..b056a19ce6598 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala @@ -25,5 +25,13 @@ private[spark] class DriverDescription( val command: Command) extends Serializable { + def copy( + jarUrl: String = jarUrl, + mem: Int = mem, + cores: Int = cores, + supervise: Boolean = supervise, + command: Command = command): DriverDescription = + new DriverDescription(jarUrl, mem, cores, supervise, command) + override def toString: String = s"DriverDescription (${command.mainClass})" } diff --git 
a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 9a7a113c95715..0401b15446a7b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -33,7 +33,11 @@ import org.apache.spark.util.Utils * fault recovery without spinning up a lot of processes. */ private[spark] -class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) +class LocalSparkCluster( + numWorkers: Int, + coresPerWorker: Int, + memoryPerWorker: Int, + conf: SparkConf) extends Logging { private val localHostname = Utils.localHostName() @@ -43,9 +47,11 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") + // Disable REST server on Master in this mode unless otherwise specified + val _conf = conf.clone().setIfMissing("spark.master.rest.enabled", "false") + /* Start the Master */ - val conf = new SparkConf(false) - val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) + val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) masterActorSystems += masterSystem val masterUrl = "spark://" + localHostname + ":" + masterPort val masters = Array(masterUrl) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 039c8719e2867..53e18c4bcec23 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -26,7 +26,7 @@ import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.{RedirectThread, Utils} /** - * A main class used by spark-submit to launch Python applications. 
It executes python as a + * A main class used to launch Python applications. It executes python as a * subprocess and then has it connect back to the JVM to access system properties, etc. */ object PythonRunner { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 57f9faf5ddd1d..e0a32fb65cd51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} @@ -52,18 +52,13 @@ class SparkHadoopUtil extends Logging { * do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems */ def runAsSparkUser(func: () => Unit) { - val user = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER) - if (user != SparkContext.SPARK_UNKNOWN_USER) { - logDebug("running as user: " + user) - val ugi = UserGroupInformation.createRemoteUser(user) - transferCredentials(UserGroupInformation.getCurrentUser(), ugi) - ugi.doAs(new PrivilegedExceptionAction[Unit] { - def run: Unit = func() - }) - } else { - logDebug("running as SPARK_UNKNOWN_USER") - func() - } + val user = Utils.getCurrentUserName() + logDebug("running as user: " + user) + val ugi = UserGroupInformation.createRemoteUser(user) + transferCredentials(UserGroupInformation.getCurrentUser(), ugi) + ugi.doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { @@ -133,16 +128,15 @@ class SparkHadoopUtil extends 
Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesRead = f() Some(() => f() - baselineBytesRead) } catch { - case e: NoSuchMethodException => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e) None } @@ -156,26 +150,23 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. 
*/ - private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesWritten = f() Some(() => f() - baselineBytesWritten) } catch { - case e: NoSuchMethodException => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) None } } } - private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { - val qualifiedPath = path.getFileSystem(conf).makeQualified(path) - val scheme = qualifiedPath.toUri().getScheme() - val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + private def getFileSystemThreadStatistics(): Seq[AnyRef] = { + val stats = FileSystem.getAllStatistics() stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } @@ -195,6 +186,21 @@ class SparkHadoopUtil extends Logging { val method = context.getClass.getMethod("getConfiguration") method.invoke(context).asInstanceOf[Configuration] } + + /** + * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the + * given path points to a file, return a single-element collection containing [[FileStatus]] of + * that file. 
+ */ + def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { + def recurse(path: Path) = { + val (directories, leaves) = fs.listStatus(path).partition(_.isDir) + leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath)) + } + + val baseStatus = fs.getFileStatus(basePath) + if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 955cbd6dab96d..4c4110812e0a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -18,13 +18,37 @@ package org.apache.spark.deploy import java.io.{File, PrintStream} -import java.lang.reflect.{Modifier, InvocationTargetException} +import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL +import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import org.apache.spark.executor.ExecutorURLClassLoader -import org.apache.spark.util.Utils +import org.apache.hadoop.fs.Path +import org.apache.hadoop.security.UserGroupInformation +import org.apache.ivy.Ivy +import org.apache.ivy.core.LogOptions +import org.apache.ivy.core.module.descriptor._ +import org.apache.ivy.core.module.id.{ArtifactId, ModuleId, ModuleRevisionId} +import org.apache.ivy.core.report.ResolveReport +import org.apache.ivy.core.resolve.ResolveOptions +import org.apache.ivy.core.retrieve.RetrieveOptions +import org.apache.ivy.core.settings.IvySettings +import org.apache.ivy.plugins.matcher.GlobPatternMatcher +import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} + +import org.apache.spark.SPARK_VERSION +import org.apache.spark.deploy.rest._ +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} + +/** + * Whether to submit, kill, 
or request the status of an application. + * The latter two operations are currently supported only for standalone cluster mode. + */ +private[spark] object SparkSubmitAction extends Enumeration { + type SparkSubmitAction = Value + val SUBMIT, KILL, REQUEST_STATUS = Value +} /** * Main gateway of launching a Spark application. @@ -57,35 +81,127 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // Exposed for testing - private[spark] var exitFn: () => Unit = () => System.exit(-1) + private[spark] var exitFn: () => Unit = () => System.exit(1) private[spark] var printStream: PrintStream = System.err - private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) - private[spark] def printErrorAndExit(str: String) = { + private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") exitFn() } + private[spark] def printVersionAndExit(): Unit = { + printStream.println("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + printStream.println("Type --help for more information.") + exitFn() + } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { printStream.println(appArgs) } - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) - launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose) + appArgs.action match { + case SparkSubmitAction.SUBMIT => submit(appArgs) + case SparkSubmitAction.KILL => kill(appArgs) + case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + } + } + + /** Kill an existing submission using the REST protocol. Standalone cluster mode only. 
*/ + private def kill(args: SparkSubmitArguments): Unit = { + new StandaloneRestClient() + .killSubmission(args.master, args.submissionToKill) + } + + /** + * Request the status of an existing submission using the REST protocol. + * Standalone cluster mode only. + */ + private def requestStatus(args: SparkSubmitArguments): Unit = { + new StandaloneRestClient() + .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) } /** - * @return a tuple containing - * (1) the arguments for the child process, - * (2) a list of classpath entries for the child, - * (3) a list of system properties and env vars, and - * (4) the main class for the child + * Submit the application using the provided parameters. + * + * This runs in two steps. First, we prepare the launch environment by setting up + * the appropriate classpath, system properties, and application arguments for + * running the child main class based on the cluster manager and the deploy mode. + * Second, we use this launch environment to invoke the main method of the child + * main class. */ - private[spark] def createLaunchEnv(args: SparkSubmitArguments) - : (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = { + private[spark] def submit(args: SparkSubmitArguments): Unit = { + val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) + + def doRunMain(): Unit = { + if (args.proxyUser != null) { + val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser, + UserGroupInformation.getCurrentUser()) + try { + proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + } + }) + } catch { + case e: Exception => + // Hadoop's AuthorizationException suppresses the exception's stack trace, which + // makes the message printed to the output by the JVM not very helpful. 
Instead, + // detect exceptions with empty stack traces here, and treat them differently. + if (e.getStackTrace().length == 0) { + printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + exitFn() + } else { + throw e + } + } + } else { + runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + } + } + + // In standalone cluster mode, there are two submission gateways: + // (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper + // (2) The new REST-based gateway introduced in Spark 1.3 + // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over + // to use the legacy gateway if the master endpoint turns out to be not a REST server. + if (args.isStandaloneCluster && args.useRest) { + try { + printStream.println("Running Spark using the REST application submission protocol.") + doRunMain() + } catch { + // Fail over to use the legacy submission gateway + case e: SubmitRestConnectionException => + printWarning(s"Master endpoint ${args.master} was not a REST server. " + + "Falling back to legacy submission gateway instead.") + args.useRest = false + submit(args) + } + // In all other modes, just run the main class as prepared + } else { + doRunMain() + } + } - // Values to return + /** + * Prepare the environment for submitting an application. + * This returns a 4-tuple: + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a map of system properties, and + * (4) the main class for the child + * Exposed for testing. 
+ */ + private[spark] def prepareSubmitEnvironment(args: SparkSubmitArguments) + : (Seq[String], Seq[String], Map[String, String], String) = { + // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() val sysProps = new HashMap[String, String]() @@ -134,24 +250,60 @@ object SparkSubmit { } } + val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + + // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files + // too for packages that include Python code + val resolvedMavenCoordinates = + SparkSubmitUtils.resolveMavenCoordinates( + args.packages, Option(args.repositories), Option(args.ivyRepoPath)) + if (!resolvedMavenCoordinates.trim.isEmpty) { + if (args.jars == null || args.jars.trim.isEmpty) { + args.jars = resolvedMavenCoordinates + } else { + args.jars += s",$resolvedMavenCoordinates" + } + if (args.isPython) { + if (args.pyFiles == null || args.pyFiles.trim.isEmpty) { + args.pyFiles = resolvedMavenCoordinates + } else { + args.pyFiles += s",$resolvedMavenCoordinates" + } + } + } + + // Require all python files to be local, so we can add them to the PYTHONPATH + // In YARN cluster mode, python files are distributed as regular files, which can be non-local + if (args.isPython && !isYarnCluster) { + if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { + printErrorAndExit(s"Only local python files are supported: $args.primaryResource") + } + val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",") + if (nonLocalPyFiles.nonEmpty) { + printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles") + } + } + // The following modes are not supported or applicable (clusterManager, deployMode) match { case (MESOS, CLUSTER) => printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") - case (_, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for 
python applications.") + case (STANDALONE, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python " + + "applications on standalone clusters.") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") + case (_, CLUSTER) if isThriftServer(args.mainClass) => + printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") case _ => } // If we're running a python app, set the main class to our specific python runner - if (args.isPython) { + if (args.isPython && deployMode == CLIENT) { if (args.primaryResource == PYSPARK_SHELL) { - args.mainClass = "py4j.GatewayServer" - args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0") + args.mainClass = "org.apache.spark.api.python.PythonGatewayServer" } else { // If a python file is provided, add it to the child arguments and list of files to deploy. // Usage: PythonAppRunner
[app arguments] @@ -165,6 +317,13 @@ object SparkSubmit { } } + // In yarn-cluster mode for a python app, add primary resource and pyFiles to files + // that can be distributed with the job + if (args.isPython && isYarnCluster) { + args.files = mergeFileLists(args.files, args.primaryResource) + args.files = mergeFileLists(args.files, args.pyFiles) + } + // Special flag to avoid deprecation warnings at the client sysProps("SPARK_SUBMIT") = "true" @@ -176,6 +335,7 @@ object SparkSubmit { OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), + OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.driver.memory"), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, @@ -186,9 +346,13 @@ object SparkSubmit { sysProp = "spark.driver.extraLibraryPath"), // Standalone cluster only + // Do not set CL arguments here because there are multiple possibilities for the main class OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), + OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"), + OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"), + OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER, + sysProp = "spark.driver.supervise"), // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), @@ -200,6 +364,7 @@ object SparkSubmit { // Yarn cluster only OptionAssigner(args.name, 
YARN, CLUSTER, clOption = "--name"), OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), + OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), @@ -228,7 +393,6 @@ object SparkSubmit { if (args.childArgs != null) { childArgs ++= args.childArgs } } - // Map all arguments to command-line options or system properties for our chosen mode for (opt <- options) { if (opt.value != null && @@ -242,7 +406,6 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python files, the primary resource is already distributed as a regular file - val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER if (!isYarnCluster && !args.isPython) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { @@ -251,14 +414,21 @@ object SparkSubmit { sysProps.put("spark.jars", jars.mkString(",")) } - // In standalone-cluster mode, use Client as a wrapper around the user class - if (clusterManager == STANDALONE && deployMode == CLUSTER) { - childMainClass = "org.apache.spark.deploy.Client" - if (args.supervise) { - childArgs += "--supervise" + // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+). + // All Spark parameters are expected to be passed to the client through system properties. 
+ if (args.isStandaloneCluster) { + if (args.useRest) { + childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient" + childArgs += (args.primaryResource, args.mainClass) + } else { + // In legacy standalone cluster mode, use Client as a wrapper around the user class + childMainClass = "org.apache.spark.deploy.Client" + if (args.supervise) { childArgs += "--supervise" } + Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) } + Option(args.driverCores).foreach { c => childArgs += ("--cores", c) } + childArgs += "launch" + childArgs += (args.master, args.primaryResource, args.mainClass) } - childArgs += "launch" - childArgs += (args.master, args.primaryResource, args.mainClass) if (args.childArgs != null) { childArgs ++= args.childArgs } @@ -267,10 +437,22 @@ object SparkSubmit { // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" - if (args.primaryResource != SPARK_INTERNAL) { - childArgs += ("--jar", args.primaryResource) + if (args.isPython) { + val mainPyFile = new Path(args.primaryResource).getName + childArgs += ("--primary-py-file", mainPyFile) + if (args.pyFiles != null) { + // These files will be distributed to each machine's working directory, so strip the + // path prefix + val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",") + childArgs += ("--py-files", pyFilesNames) + } + childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + } else { + if (args.primaryResource != SPARK_INTERNAL) { + childArgs += ("--jar", args.primaryResource) + } + childArgs += ("--class", args.mainClass) } - childArgs += ("--class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) } } @@ -283,7 +465,7 @@ object SparkSubmit { // Ignore invalid spark.driver.host in cluster modes. 
if (deployMode == CLUSTER) { - sysProps -= ("spark.driver.host") + sysProps -= "spark.driver.host" } // Resolve paths in certain spark properties @@ -312,12 +494,18 @@ object SparkSubmit { (childArgs, childClasspath, sysProps, childMainClass) } - private def launch( - childArgs: ArrayBuffer[String], - childClasspath: ArrayBuffer[String], + /** + * Run the main method of the child class using the provided launch environment. + * + * Note that this main class will not be the one provided by the user if we're + * running cluster deploy mode or python applications. + */ + private def runMain( + childArgs: Seq[String], + childClasspath: Seq[String], sysProps: Map[String, String], childMainClass: String, - verbose: Boolean = false) { + verbose: Boolean): Unit = { if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -326,8 +514,14 @@ object SparkSubmit { printStream.println("\n") } - val loader = new ExecutorURLClassLoader(new Array[URL](0), - Thread.currentThread.getContextClassLoader) + val loader = + if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + new ChildFirstURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } else { + new MutableURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } Thread.currentThread.setContextClassLoader(loader) for (jar <- childClasspath) { @@ -346,8 +540,8 @@ object SparkSubmit { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { - println(s"Failed to load main class $childMainClass.") - println("You need to build Spark with -Phive and -Phive-thriftserver.") + printStream.println(s"Failed to load main class $childMainClass.") + printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -361,17 +555,25 @@ object SparkSubmit { if 
(!Modifier.isStatic(mainMethod.getModifiers)) { throw new IllegalStateException("The main method in the given main class must be static") } + + def findCause(t: Throwable): Throwable = t match { + case e: UndeclaredThrowableException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: InvocationTargetException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: Throwable => + e + } + try { mainMethod.invoke(null, childArgs.toArray) } catch { - case e: InvocationTargetException => e.getCause match { - case cause: Throwable => throw cause - case null => throw e - } + case t: Throwable => + throw findCause(t) } } - private def addJarToClasspath(localJar: String, loader: ExecutorURLClassLoader) { + private def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { val uri = Utils.resolveURI(localJar) uri.getScheme match { case "file" | "local" => @@ -407,6 +609,13 @@ object SparkSubmit { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } + /** + * Return whether the given main class represents a thrift server. + */ + private[spark] def isThriftServer(mainClass: String): Boolean = { + mainClass == "org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" + } + /** * Return whether the given primary resource requires running python. */ @@ -430,11 +639,213 @@ object SparkSubmit { } } +/** Provides utility functions to be used inside SparkSubmit. */ +private[spark] object SparkSubmitUtils { + + // Exposed for testing + private[spark] var printStream = SparkSubmit.printStream + + /** + * Represents a Maven Coordinate + * @param groupId the groupId of the coordinate + * @param artifactId the artifactId of the coordinate + * @param version the version of the coordinate + */ + private[spark] case class MavenCoordinate(groupId: String, artifactId: String, version: String) + +/** + * Extracts maven coordinates from a comma-delimited string. 
Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. The latter provides + * simplicity for Spark Package users. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ + private[spark] def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { + coordinates.split(",").map { p => + val splits = p.replace("/", ":").split(":") + require(splits.length == 3, s"Provided Maven Coordinates must be in the form " + + s"'groupId:artifactId:version'. The coordinate provided is: $p") + require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " + + s"be whitespace. The groupId provided is: ${splits(0)}") + require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " + + s"be whitespace. The artifactId provided is: ${splits(1)}") + require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " + + s"be whitespace. The version provided is: ${splits(2)}") + new MavenCoordinate(splits(0), splits(1), splits(2)) + } + } + + /** + * Extracts maven coordinates from a comma-delimited string + * @param remoteRepos Comma-delimited string of remote repositories + * @return A ChainResolver used by Ivy to search for and resolve dependencies. 
+ */ + private[spark] def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = { + // We need a chain resolver if we want to check multiple repositories + val cr = new ChainResolver + cr.setName("list") + + // the biblio resolver resolves POM declared dependencies + val br: IBiblioResolver = new IBiblioResolver + br.setM2compatible(true) + br.setUsepoms(true) + br.setName("central") + cr.add(br) + + val sp: IBiblioResolver = new IBiblioResolver + sp.setM2compatible(true) + sp.setUsepoms(true) + sp.setRoot("http://dl.bintray.com/spark-packages/maven") + sp.setName("spark-packages") + cr.add(sp) + + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + } + } + cr + } + + /** + * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath + * (will append to jars in SparkSubmit). The name of the jar is given + * after a '!' by Ivy. It also sometimes contains '(bundle)' after '.jar'. Remove that as well. 
+ * @param artifacts Sequence of dependencies that were resolved and retrieved + * @param cacheDirectory directory where jars are cached + * @return a comma-delimited list of paths for the dependencies + */ + private[spark] def resolveDependencyPaths( + artifacts: Array[AnyRef], + cacheDirectory: File): String = { + artifacts.map { artifactInfo => + val artifactString = artifactInfo.toString + val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1) + cacheDirectory.getAbsolutePath + File.separator + + jarName.substring(0, jarName.lastIndexOf(".jar") + 4) + }.mkString(",") + } + + /** Adds the given maven coordinates to Ivy's module descriptor. */ + private[spark] def addDependenciesToIvy( + md: DefaultModuleDescriptor, + artifacts: Seq[MavenCoordinate], + ivyConfName: String): Unit = { + artifacts.foreach { mvn => + val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) + val dd = new DefaultDependencyDescriptor(ri, false, false) + dd.addDependencyConfiguration(ivyConfName, ivyConfName) + printStream.println(s"${dd.getDependencyId} added as a dependency") + md.addDependency(dd) + } + } + + /** A nice function to use in tests as well. Values are dummy strings. 
*/ + private[spark] def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( + ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) + + /** + * Resolves any dependencies that were supplied through maven coordinates + * @param coordinates Comma-delimited string of maven coordinates + * @param remoteRepos Comma-delimited string of remote repositories other than maven central + * @param ivyPath The path to the local ivy repository + * @return The comma-delimited path to the jars of the given maven artifacts including their + * transitive dependencies + */ + private[spark] def resolveMavenCoordinates( + coordinates: String, + remoteRepos: Option[String], + ivyPath: Option[String], + isTest: Boolean = false): String = { + if (coordinates == null || coordinates.trim.isEmpty) { + "" + } else { + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") + } else { + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") + } + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to 
download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) + } else { + resolveOptions.setDownload(true) + } + + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + md.setDefaultConf(ivyConfName) + + // Add an exclusion rule for Spark and Scala Library + val sparkArtifacts = new ArtifactId(new ModuleId("org.apache.spark", "*"), "*", "*", "*") + val sparkDependencyExcludeRule = + new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null) + sparkDependencyExcludeRule.addConfiguration(ivyConfName) + val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*") + val scalaDependencyExcludeRule = + new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null) + scalaDependencyExcludeRule.addConfiguration(ivyConfName) + + // Exclude any Spark dependencies, and add all supplied maven artifacts as dependencies + md.addExcludeRule(sparkDependencyExcludeRule) + md.addExcludeRule(scalaDependencyExcludeRule) + addDependenciesToIvy(md, artifacts, ivyConfName) + + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } + } +} + /** * Provides an indirection layer for passing arguments as system properties or flags to * the user's driver program or to downstream 
launcher tools. */ -private[spark] case class OptionAssigner( +private case class OptionAssigner( value: String, clusterManager: Int, deployMode: Int, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 47059b08a397f..82e66a374249c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -22,6 +22,7 @@ import java.util.jar.JarFile import scala.collection.mutable.{ArrayBuffer, HashMap} +import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.util.Utils /** @@ -39,8 +40,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var driverExtraClassPath: String = null var driverExtraLibraryPath: String = null var driverExtraJavaOptions: String = null - var driverCores: String = null - var supervise: Boolean = false var queue: String = null var numExecutors: String = null var files: String = null @@ -50,10 +49,22 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var name: String = null var childArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var jars: String = null + var packages: String = null + var repositories: String = null + var ivyRepoPath: String = null var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() + var proxyUser: String = null + + // Standalone cluster mode only + var supervise: Boolean = false + var driverCores: String = null + var submissionToKill: String = null + var submissionToRequestStatusFor: String = null + var useRest: Boolean = true // used internally /** Default properties present in the currently defined defaults file. 
*/ lazy val defaultSparkProperties: HashMap[String, String] = { @@ -79,7 +90,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St // Use `sparkProperties` map along with env vars to fill in any missing parameters loadEnvironmentArguments() - checkRequiredArguments() + validateArguments() /** * Merge values from the default properties file with those specified through --conf. @@ -104,10 +115,22 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orElse(sparkProperties.get("spark.master")) .orElse(env.get("MASTER")) .orNull + driverExtraClassPath = Option(driverExtraClassPath) + .orElse(sparkProperties.get("spark.driver.extraClassPath")) + .orNull + driverExtraJavaOptions = Option(driverExtraJavaOptions) + .orElse(sparkProperties.get("spark.driver.extraJavaOptions")) + .orNull + driverExtraLibraryPath = Option(driverExtraLibraryPath) + .orElse(sparkProperties.get("spark.driver.extraLibraryPath")) + .orNull driverMemory = Option(driverMemory) .orElse(sparkProperties.get("spark.driver.memory")) .orElse(env.get("SPARK_DRIVER_MEMORY")) .orNull + driverCores = Option(driverCores) + .orElse(sparkProperties.get("spark.driver.cores")) + .orNull executorMemory = Option(executorMemory) .orElse(sparkProperties.get("spark.executor.memory")) .orElse(env.get("SPARK_EXECUTOR_MEMORY")) @@ -120,6 +143,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orNull name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull + ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) @@ -159,10 +183,21 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St if (name == null && primaryResource != null) { name = 
Utils.stripDirectory(primaryResource) } + + // Action should be SUBMIT unless otherwise specified + action = Option(action).getOrElse(SUBMIT) } /** Ensure that required fields exists. Call this only once all defaults are loaded. */ - private def checkRequiredArguments(): Unit = { + private def validateArguments(): Unit = { + action match { + case SUBMIT => validateSubmitArguments() + case KILL => validateKillArguments() + case REQUEST_STATUS => validateStatusRequestArguments() + } + } + + private def validateSubmitArguments(): Unit = { if (args.length == 0) { printUsageAndExit(-1) } @@ -176,18 +211,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script") } - // Require all python files to be local, so we can add them to the PYTHONPATH - if (isPython) { - if (Utils.nonLocalPaths(primaryResource).nonEmpty) { - SparkSubmit.printErrorAndExit(s"Only local python files are supported: $primaryResource") - } - val nonLocalPyFiles = Utils.nonLocalPaths(pyFiles).mkString(",") - if (nonLocalPyFiles.nonEmpty) { - SparkSubmit.printErrorAndExit( - s"Only local additional python files are supported: $nonLocalPyFiles") - } - } - if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") if (!hasHadoopEnv && !Utils.isTesting) { @@ -197,6 +220,29 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } + private def validateKillArguments(): Unit = { + if (!master.startsWith("spark://")) { + SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!") + } + if (submissionToKill == null) { + SparkSubmit.printErrorAndExit("Please specify a submission to kill.") + } + } + + private def validateStatusRequestArguments(): Unit = { + if (!master.startsWith("spark://")) { + SparkSubmit.printErrorAndExit( + "Requesting submission statuses is only 
supported in standalone mode!") + } + if (submissionToRequestStatusFor == null) { + SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") + } + } + + def isStandaloneCluster: Boolean = { + master.startsWith("spark://") && deployMode == "cluster" + } + override def toString = { s"""Parsed arguments: | master $master @@ -221,6 +267,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | name $name | childArgs [${childArgs.mkString(" ")}] | jars $jars + | packages $packages + | repositories $repositories | verbose $verbose | |Spark properties used, including those specified through @@ -303,6 +351,22 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St propertiesFile = value parse(tail) + case ("--kill") :: value :: tail => + submissionToKill = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + } + action = KILL + parse(tail) + + case ("--status") :: value :: tail => + submissionToRequestStatusFor = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + } + action = REQUEST_STATUS + parse(tail) + case ("--supervise") :: tail => supervise = true parse(tail) @@ -327,6 +391,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St jars = Utils.resolveURIs(value) parse(tail) + case ("--packages") :: value :: tail => + packages = value + parse(tail) + + case ("--repositories") :: value :: tail => + repositories = value + parse(tail) + case ("--conf" | "-c") :: value :: tail => value.split("=", 2).toSeq match { case Seq(k, v) => sparkProperties(k) = v @@ -334,6 +406,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } parse(tail) + case ("--proxy-user") :: value :: tail => + proxyUser = value + parse(tail) + case ("--help" | "-h") :: tail => printUsageAndExit(0) @@ -341,6 +417,9 @@ private[spark] 
class SparkSubmitArguments(args: Seq[String], env: Map[String, St verbose = true parse(tail) + case ("--version") :: tail => + SparkSubmit.printVersionAndExit() + case EQ_SEPARATED_OPT(opt, value) :: tail => parse(opt :: value :: tail) @@ -367,7 +446,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St outStream.println("Unknown/unsupported param " + unknownParam) } outStream.println( - """Usage: spark-submit [options] [app options] + """Usage: spark-submit [options] [app arguments] + |Usage: spark-submit --kill [submission ID] --master [spark://...] + |Usage: spark-submit --status [submission ID] --master [spark://...] + | |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -377,6 +459,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | --name NAME A name of your application. | --jars JARS Comma-separated list of local jars to include on the driver | and executor classpaths. + | --packages Comma-separated list of maven coordinates of jars to include + | on the driver and executor classpaths. Will search the local + | maven repo, then maven central and any additional remote + | repositories given by --repositories. The format for the + | coordinates should be groupId:artifactId:version. + | --repositories Comma-separated list of additional remote repositories to + | search for the maven coordinates given with --packages. | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place | on the PYTHONPATH for Python apps. | --files FILES Comma-separated list of files to be placed in the working @@ -395,17 +484,24 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). | + | --proxy-user NAME User to impersonate when submitting the application. 
+ | | --help, -h Show this help message and exit | --verbose, -v Print additional debug output + | --version, Print the version of current Spark | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). | --supervise If given, restarts the driver on failure. + | --kill SUBMISSION_ID If given, kills the driver specified. + | --status SUBMISSION_ID If given, requests the status of the driver specified. | | Spark standalone and Mesos only: | --total-executor-cores NUM Total cores for all executors. | | YARN-only: + | --driver-cores NUM Number of cores used by the driver, only in cluster mode + | (Default: 1). | --executor-cores NUM Number of cores per executor (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 39a7b0319b6a1..ffe940fbda2fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -47,7 +47,7 @@ private[spark] class AppClient( conf: SparkConf) extends Logging { - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl) + val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 @@ -107,8 +107,9 @@ private[spark] class AppClient( def changeMaster(url: String) { // activeMasterUrl is a valid Spark url since we receive it from master. 
activeMasterUrl = url - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) - masterAddress = Master.toAkkaAddress(activeMasterUrl) + master = context.actorSelection( + Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem))) + masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem)) } private def isPossibleMaster(remoteUrl: Address) = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2b084a2d73b78..885fa0fdbf85b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -173,9 +173,10 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val logInfos = statusList .filter { entry => try { - val modTime = getModificationTime(entry) - newLastModifiedTime = math.max(newLastModifiedTime, modTime) - modTime >= lastModifiedTime + getModificationTime(entry).map { time => + newLastModifiedTime = math.max(newLastModifiedTime, time) + time >= lastModifiedTime + }.getOrElse(false) } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on @@ -193,7 +194,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis None } } - .sortBy { info => (-info.endTime, -info.startTime) } + .sortWith(compareAppInfo) lastModifiedTime = newLastModifiedTime @@ -203,7 +204,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!logInfos.isEmpty) { val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() def addIfAbsent(info: FsApplicationHistoryInfo) = { - if (!newApps.contains(info.id)) { + if (!newApps.contains(info.id) || + newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) && + 
!info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) { newApps += (info.id -> info) } } @@ -211,7 +214,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val newIterator = logInfos.iterator.buffered val oldIterator = applications.values.iterator.buffered while (newIterator.hasNext && oldIterator.hasNext) { - if (newIterator.head.endTime > oldIterator.head.endTime) { + if (compareAppInfo(newIterator.head, oldIterator.head)) { addIfAbsent(newIterator.next) } else { addIfAbsent(oldIterator.next) @@ -227,12 +230,24 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } } + /** + * Comparison function that defines the sort order for the application listing. + * + * @return Whether `i1` should precede `i2`. + */ + private def compareAppInfo( + i1: FsApplicationHistoryInfo, + i2: FsApplicationHistoryInfo): Boolean = { + if (i1.endTime != i2.endTime) i1.endTime >= i2.endTime else i1.startTime >= i2.startTime + } + /** * Replays the events in the specified log file and returns information about the associated * application. 
*/ private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationHistoryInfo = { val logPath = eventLog.getPath() + logInfo(s"Replaying log path: $logPath") val (logInput, sparkVersion) = if (isLegacyLogDirectory(eventLog)) { openLegacyEventLog(logPath) @@ -242,14 +257,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis try { val appListener = new ApplicationEventListener bus.addListener(appListener) - bus.replay(logInput, sparkVersion) + bus.replay(logInput, sparkVersion, logPath.toString) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), appListener.appName.getOrElse(NOT_STARTED), appListener.startTime.getOrElse(-1L), appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog), + getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), isApplicationCompleted(eventLog)) } finally { @@ -308,11 +323,16 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDir() - private def getModificationTime(fsEntry: FileStatus): Long = { - if (fsEntry.isDir) { - fs.listStatus(fsEntry.getPath).map(_.getModificationTime()).max + /** + * Returns the modification time of the given event log. If the status points at an empty + * directory, `None` is returned, indicating that there isn't an event log at that location. 
+ */ + private def getModificationTime(fsEntry: FileStatus): Option[Long] = { + if (isLegacyLogDirectory(fsEntry)) { + val statusList = fs.listStatus(fsEntry.getPath) + if (!statusList.isEmpty) Some(statusList.map(_.getModificationTime()).max) else None } else { - fsEntry.getModificationTime() + Some(fsEntry.getModificationTime()) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index e4e7bc2216014..26ebc75971c66 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -61,9 +61,10 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { // page, `...` will be displayed. if (allApps.size > 0) { val leftSideIndices = - rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _) + rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _, requestedIncomplete) val rightSideIndices = - rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount) + rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount, + requestedIncomplete)

Showing {actualFirst + 1}-{last + 1} of {allApps.size} @@ -122,8 +123,10 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { "Spark User", "Last Updated") - private def rangeIndices(range: Seq[Int], condition: Int => Boolean): Seq[Node] = { - range.filter(condition).map(nextPage => {nextPage} ) + private def rangeIndices(range: Seq[Int], condition: Int => Boolean, showIncomplete: Boolean): + Seq[Node] = { + range.filter(condition).map(nextPage => + {nextPage} ) } private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ede0a9dbefb8d..a962dc4af2f6c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -90,9 +90,9 @@ private[spark] class ApplicationInfo( } } - private val myMaxCores = desc.maxCores.getOrElse(defaultCores) + val requestedCores = desc.maxCores.getOrElse(defaultCores) - def coresLeft: Int = myMaxCores - coresGranted + def coresLeft: Int = requestedCores - coresGranted private var _retryCount = 0 diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index d92d99310a583..8cc6ec1e8192c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -43,6 +43,7 @@ import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI +import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import 
org.apache.spark.ui.SparkUI @@ -52,12 +53,12 @@ private[spark] class Master( host: String, port: Int, webUiPort: Int, - val securityMgr: SecurityManager) + val securityMgr: SecurityManager, + val conf: SparkConf) extends Actor with ActorLogReceive with Logging with LeaderElectable { import context.dispatcher // to use Akka's scheduler.schedule() - val conf = new SparkConf val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -121,6 +122,17 @@ private[spark] class Master( throw new SparkException("spark.deploy.defaultCores must be positive") } + // Alternative application submission gateway that is stable across Spark versions + private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) + private val restServer = + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 6066) + Some(new StandaloneRestServer(host, port, self, masterUrl, conf)) + } else { + None + } + private val restServerBoundPort = restServer.map(_.start()) + override def preStart() { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") @@ -174,6 +186,7 @@ private[spark] class Master( recoveryCompletionTask.cancel() } webUi.stop() + restServer.foreach(_.stop()) masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() @@ -421,7 +434,9 @@ private[spark] class Master( } case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, + sender ! MasterStateResponse( + host, port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, drivers.toArray, completedDrivers.toArray, state) } @@ -429,8 +444,8 @@ private[spark] class Master( timeOutDeadWorkers() } - case RequestWebUIPort => { - sender ! WebUIPortResponse(webUi.boundPort) + case BoundPortsRequest => { + sender ! 
BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) } } @@ -656,7 +671,7 @@ private[spark] class Master( def registerApplication(app: ApplicationInfo): Unit = { val appAddress = app.driver.path.address - if (addressToWorker.contains(appAddress)) { + if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return } @@ -746,7 +761,7 @@ private[spark] class Master( val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}") try { - replayBus.replay(logInput, sparkVersion) + replayBus.replay(logInput, sparkVersion, eventLogFile) } finally { logInput.close() } @@ -851,7 +866,7 @@ private[spark] object Master extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) + val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) actorSystem.awaitTermination() } @@ -860,9 +875,9 @@ private[spark] object Master extends Logging { * * @throws SparkException if the url is invalid */ - def toAkkaUrl(sparkUrl: String): String = { + def toAkkaUrl(sparkUrl: String, protocol: String): String = { val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName) + AkkaUtils.address(protocol, systemName, host, port, actorName) } /** @@ -870,24 +885,31 @@ private[spark] object Master extends Logging { * * @throws SparkException if the url is invalid */ - def toAkkaAddress(sparkUrl: String): Address = { + def toAkkaAddress(sparkUrl: String, protocol: String): Address = { val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address("akka.tcp", systemName, host, port) + Address(protocol, systemName, host, port) } + /** + * Start the Master and return a four 
tuple of: + * (1) The Master actor system + * (2) The bound port + * (3) The web UI bound port + * (4) The REST server bound port, if any + */ def startSystemAndActor( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int) = { + conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, securityManager = securityMgr) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, - securityMgr), actorName) + val actor = actorSystem.actorOf( + Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) val timeout = AkkaUtils.askTimeout(conf) - val respFuture = actor.ask(RequestWebUIPort)(timeout) - val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] - (actorSystem, boundPort, resp.webUIBoundPort) + val portsRequest = actor.ask(BoundPortsRequest)(timeout) + val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] + (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index db72d8ae9bdaf..15c6296888f70 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -36,7 +36,7 @@ private[master] object MasterMessages { case object CompleteRecovery - case object RequestWebUIPort + case object BoundPortsRequest - case class WebUIPortResponse(webUIBoundPort: Int) + case class BoundPortsResponse(actorPort: Int, webUIPort: Int, restPort: Option[Int]) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 
7ca3b08a28728..9dd96493ee48d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -46,19 +46,23 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] val state = Await.result(stateFuture, timeout) - val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory") + val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User", - "State", "Duration") + val activeAppHeaders = Seq("Application ID", "Name", "Cores in Use", + "Cores Requested", "Memory per Node", "Submitted Time", "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse - val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) + val activeAppsTable = UIUtils.listingTable(activeAppHeaders, activeAppRow, activeApps) + + val completedAppHeaders = Seq("Application ID", "Name", "Cores Requested", "Memory per Node", + "Submitted Time", "User", "State", "Duration") val completedApps = state.completedApps.sortBy(_.endTime).reverse - val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) + val completedAppsTable = UIUtils.listingTable(completedAppHeaders, completeAppRow, + completedApps) - val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory", - "Main Class") + val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores", + "Memory", "Main Class") val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers) val completedDrivers = 
state.completedDrivers.sortBy(_.startTime).reverse @@ -73,6 +77,14 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
  • URL: {state.uri}
  • + { + state.restUri.map { uri => +
  • + REST URL: {uri} + (cluster mode) +
  • + }.getOrElse { Seq.empty } + }
  • Workers: {state.workers.size}
  • Cores: {state.workers.map(_.cores).sum} Total, {state.workers.map(_.coresUsed).sum} Used
  • @@ -154,7 +166,7 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } - private def appRow(app: ApplicationInfo): Seq[Node] = { + private def appRow(app: ApplicationInfo, active: Boolean): Seq[Node] = { {app.id} @@ -162,8 +174,15 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.desc.name} + { + if (active) { + + {app.coresGranted} + + } + } - {app.coresGranted} + {app.requestedCores} {Utils.megabytesToString(app.desc.memoryPerSlave)} @@ -175,6 +194,14 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } + private def activeAppRow(app: ApplicationInfo): Seq[Node] = { + appRow(app, active = true) + } + + private def completeAppRow(app: ApplicationInfo): Seq[Node] = { + appRow(app, active = false) + } + private def driverRow(driver: DriverInfo): Seq[Node] = { {driver.id} @@ -188,7 +215,7 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {Utils.megabytesToString(driver.desc.mem.toLong)} - {driver.desc.command.arguments(1)} + {driver.desc.command.arguments(2)} } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala new file mode 100644 index 0000000000000..c4be1f19e8e9f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.{DataOutputStream, FileNotFoundException} +import java.net.{HttpURLConnection, SocketException, URL} +import javax.servlet.http.HttpServletResponse + +import scala.io.Source + +import com.fasterxml.jackson.core.JsonProcessingException +import com.google.common.base.Charsets + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} + +/** + * A client that submits applications to the standalone Master using a REST protocol. + * This client is intended to communicate with the [[StandaloneRestServer]] and is + * currently used for cluster mode only. + * + * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action], + * where [action] can be one of create, kill, or status. Each type of request is represented in + * an HTTP message sent to the following prefixes: + * (1) submit - POST to /submissions/create + * (2) kill - POST /submissions/kill/[submissionId] + * (3) status - GET /submissions/status/[submissionId] + * + * In the case of (1), parameters are posted in the HTTP body in the form of JSON fields. + * Otherwise, the URL fully specifies the intended action of the client. + * + * Since the protocol is expected to be stable across Spark versions, existing fields cannot be + * changed or removed, though new optional fields can be added. In the rare event that forward or + * backward compatibility is broken, Spark must introduce a new protocol version (e.g. v2).
+ * + * The client and the server must communicate using the same version of the protocol. If there + * is a mismatch, the server will respond with the highest protocol version it supports. A future + * implementation of this client can use that information to retry using the version specified + * by the server. + */ +private[spark] class StandaloneRestClient extends Logging { + import StandaloneRestClient._ + + /** + * Submit an application specified by the parameters in the provided request. + * + * If the submission was successful, poll the status of the submission and report + * it to the user. Otherwise, report the error message provided by the server. + */ + def createSubmission( + master: String, + request: CreateSubmissionRequest): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to launch an application in $master.") + validateMaster(master) + val url = getSubmitUrl(master) + val response = postJson(url, request.toJson) + response match { + case s: CreateSubmissionResponse => + reportSubmissionStatus(master, s) + handleRestResponse(s) + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + response + } + + /** Request that the server kill the specified submission. */ + def killSubmission(master: String, submissionId: String): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to kill submission $submissionId in $master.") + validateMaster(master) + val response = post(getKillUrl(master, submissionId)) + response match { + case k: KillSubmissionResponse => handleRestResponse(k) + case unexpected => handleUnexpectedRestResponse(unexpected) + } + response + } + + /** Request the status of a submission from the server. 
*/ + def requestSubmissionStatus( + master: String, + submissionId: String, + quiet: Boolean = false): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request for the status of submission $submissionId in $master.") + validateMaster(master) + val response = get(getStatusUrl(master, submissionId)) + response match { + case s: SubmissionStatusResponse => if (!quiet) { handleRestResponse(s) } + case unexpected => handleUnexpectedRestResponse(unexpected) + } + response + } + + /** Construct a message that captures the specified parameters for submitting an application. */ + def constructSubmitRequest( + appResource: String, + mainClass: String, + appArgs: Array[String], + sparkProperties: Map[String, String], + environmentVariables: Map[String, String]): CreateSubmissionRequest = { + val message = new CreateSubmissionRequest + message.clientSparkVersion = sparkVersion + message.appResource = appResource + message.mainClass = mainClass + message.appArgs = appArgs + message.sparkProperties = sparkProperties + message.environmentVariables = environmentVariables + message.validate() + message + } + + /** Send a GET request to the specified URL. */ + private def get(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending GET request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("GET") + readResponse(conn) + } + + /** Send a POST request to the specified URL. */ + private def post(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + readResponse(conn) + } + + /** Send a POST request with the given JSON as the body to the specified URL. 
*/ + private def postJson(url: URL, json: String): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url:\n$json") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Type", "application/json") + conn.setRequestProperty("charset", "utf-8") + conn.setDoOutput(true) + val out = new DataOutputStream(conn.getOutputStream) + out.write(json.getBytes(Charsets.UTF_8)) + out.close() + readResponse(conn) + } + + /** + * Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]]. + * If the response represents an error, report the embedded message to the user. + * Exposed for testing. + */ + private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { + try { + val dataStream = + if (connection.getResponseCode == HttpServletResponse.SC_OK) { + connection.getInputStream + } else { + connection.getErrorStream + } + // If the server threw an exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") + } + } catch { + case unreachable @ (_: FileNotFoundException | _: SocketException) => + throw new SubmitRestConnectionException( + s"Unable to connect to server ${connection.getURL}", unreachable) + 
case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + throw new SubmitRestProtocolException( + "Malformed response received from server", malformed) + } + } + + /** Return the REST URL for creating a new submission. */ + private def getSubmitUrl(master: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/create") + } + + /** Return the REST URL for killing an existing submission. */ + private def getKillUrl(master: String, submissionId: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/kill/$submissionId") + } + + /** Return the REST URL for requesting the status of an existing submission. */ + private def getStatusUrl(master: String, submissionId: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/status/$submissionId") + } + + /** Return the base URL for communicating with the server, including the protocol version. */ + private def getBaseUrl(master: String): String = { + val masterUrl = master.stripPrefix("spark://").stripSuffix("/") + s"http://$masterUrl/$PROTOCOL_VERSION/submissions" + } + + /** Throw an exception if this is not standalone mode. */ + private def validateMaster(master: String): Unit = { + if (!master.startsWith("spark://")) { + throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + } + } + + /** Report the status of a newly created submission. */ + private def reportSubmissionStatus( + master: String, + submitResponse: CreateSubmissionResponse): Unit = { + if (submitResponse.success) { + val submissionId = submitResponse.submissionId + if (submissionId != null) { + logInfo(s"Submission successfully created as $submissionId. 
Polling submission state...") + pollSubmissionStatus(master, submissionId) + } else { + // should never happen + logError("Application successfully submitted, but submission ID was not provided!") + } + } else { + val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") + logError("Application submission failed" + failMessage) + } + } + + /** + * Poll the status of the specified submission and log it. + * This retries up to a fixed number of times before giving up. + */ + private def pollSubmissionStatus(master: String, submissionId: String): Unit = { + (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => + val response = requestSubmissionStatus(master, submissionId, quiet = true) + val statusResponse = response match { + case s: SubmissionStatusResponse => s + case _ => return // unexpected type, let upstream caller handle it + } + if (statusResponse.success) { + val driverState = Option(statusResponse.driverState) + val workerId = Option(statusResponse.workerId) + val workerHostPort = Option(statusResponse.workerHostPort) + val exception = Option(statusResponse.message) + // Log driver state, if present + driverState match { + case Some(state) => logInfo(s"State of driver $submissionId is now $state.") + case _ => logError(s"State of driver $submissionId was not found!") + } + // Log worker node, if present + (workerId, workerHostPort) match { + case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") + case _ => + } + // Log exception stack trace, if present + exception.foreach { e => logError(e) } + return + } + Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL) + } + logError(s"Error: Master did not recognize driver $submissionId.") + } + + /** Log the response sent by the server in the REST application submission protocol. 
*/ + private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = { + logInfo(s"Server responded with ${response.messageType}:\n${response.toJson}") + } + + /** Log an appropriate error if the response sent by the server is not of the expected type. */ + private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = { + logError(s"Error: Server responded with message of unexpected type ${unexpected.messageType}.") + } +} + +private[spark] object StandaloneRestClient { + val REPORT_DRIVER_STATUS_INTERVAL = 1000 + val REPORT_DRIVER_STATUS_MAX_TRIES = 10 + val PROTOCOL_VERSION = "v1" + + /** + * Submit an application, assuming Spark parameters are specified through the given config. + * This is abstracted to its own method for testing purposes. + */ + private[rest] def run( + appResource: String, + mainClass: String, + appArgs: Array[String], + conf: SparkConf, + env: Map[String, String] = sys.env): SubmitRestProtocolResponse = { + val master = conf.getOption("spark.master").getOrElse { + throw new IllegalArgumentException("'spark.master' must be set.") + } + val sparkProperties = conf.getAll.toMap + val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } + val client = new StandaloneRestClient + val submitRequest = client.constructSubmitRequest( + appResource, mainClass, appArgs, sparkProperties, environmentVariables) + client.createSubmission(master, submitRequest) + } + + def main(args: Array[String]): Unit = { + if (args.size < 2) { + sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]") + sys.exit(1) + } + val appResource = args(0) + val mainClass = args(1) + val appArgs = args.slice(2, args.size) + val conf = new SparkConf + run(appResource, mainClass, appArgs, conf) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala new file mode 100644 index 
0000000000000..f9e0478e4f874 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.File +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import scala.io.Source + +import akka.actor.ActorRef +import com.fasterxml.jackson.core.JsonProcessingException +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.deploy.ClientArguments._ + +/** + * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * This is intended to be embedded in the standalone Master and used in cluster mode only. 
+ * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + * + * @param host the address this server should bind to + * @param requestedPort the port this server will attempt to bind to + * @param masterActor reference to the Master actor to which requests can be sent + * @param masterUrl the URL of the Master new drivers will attempt to connect to + * @param masterConf the conf used by the Master + */ +private[spark] class StandaloneRestServer( + host: String, + requestedPort: Int, + masterActor: ActorRef, + masterUrl: String, + masterConf: SparkConf) + extends Logging { + + import StandaloneRestServer._ + + private var _server: Option[Server] = None + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. 
+ protected val baseContext = s"/$PROTOCOL_VERSION/submissions" + protected val contextToServlet = Map[String, StandaloneRestServlet]( + s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf), + s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf), + s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf), + "/*" -> new ErrorServlet // default handler + ) + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Map the servlets to their corresponding contexts and attach them to a server. + * Return a 2-tuple of the started server and the bound port. + */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, startPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/") + contextToServlet.foreach { case (prefix, servlet) => + mainHandler.addServlet(new ServletHolder(servlet), prefix) + } + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } +} + +private[rest] object StandaloneRestServer { + val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION + val SC_UNKNOWN_PROTOCOL_VERSION = 468 +} + +/** + * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. + */ +private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging { + + /** + * Serialize the given response message to JSON and send it through the response servlet. 
+ * This validates the response before sending it to ensure it is properly constructed. + */ + protected def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val message = validateResponse(responseMessage, responseServlet) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.getWriter.write(message.toJson) + } + + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown match { + case j: JObject => j.obj.map { case (k, _) => k }.toArray + case _ => Array.empty[String] // No difference + } + } + + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Throwable): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } + + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e + } + + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. 
+ */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + + /** + * Validate the response to ensure that it is correctly constructed. + * + * If it is, simply return the message as is. Otherwise, return an error response instead + * to propagate the exception back to the client and set the appropriate error code. + */ + private def validateResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + try { + responseMessage.validate() + responseMessage + } catch { + case e: Exception => + responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + handleError("Internal server error: " + formatException(e)) + } + } +} + +/** + * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StandaloneRestServlet { + + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. 
+ */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } + sendResponse(responseMessage, response) + } + + protected def handleKill(submissionId: String): KillSubmissionResponse = { + val askTimeout = AkkaUtils.askTimeout(conf) + val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val k = new KillSubmissionResponse + k.serverSparkVersion = sparkVersion + k.message = response.message + k.submissionId = submissionId + k.success = response.success + k + } +} + +/** + * A servlet for handling status requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StandaloneRestServlet { + + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. 
+ */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } + sendResponse(responseMessage, response) + } + + protected def handleStatus(submissionId: String): SubmissionStatusResponse = { + val askTimeout = AkkaUtils.askTimeout(conf) + val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) + val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } + val d = new SubmissionStatusResponse + d.serverSparkVersion = sparkVersion + d.submissionId = submissionId + d.success = response.found + d.driverState = response.state.map(_.toString).orNull + d.workerId = response.workerId.orNull + d.workerHostPort = response.workerHostPort.orNull + d.message = message.orNull + d + } +} + +/** + * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class SubmitRequestServlet( + masterActor: ActorRef, + masterUrl: String, + conf: SparkConf) + extends StandaloneRestServlet { + + /** + * Submit an application to the Master with parameters specified in the request. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If the request is successfully processed, return an appropriate response to the + * client indicating so. Otherwise, return error instead. 
+ */ + protected override def doPost( + requestServlet: HttpServletRequest, + responseServlet: HttpServletResponse): Unit = { + val responseMessage = + try { + val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + // The request should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. + requestMessage.validate() + handleSubmit(requestMessageJson, requestMessage, responseServlet) + } catch { + // The client failed to provide a valid JSON, so this is not our fault + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Malformed request: " + formatException(e)) + } + sendResponse(responseMessage, responseServlet) + } + + /** + * Handle the submit request and construct an appropriate response to return to the client. + * + * This assumes that the request message is already successfully validated. + * If the request message is not of the expected type, return error to the client. 
+ */ + private def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val askTimeout = AkkaUtils.askTimeout(conf) + val driverDescription = buildDriverDescription(submitRequest) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success + submitResponse.submissionId = response.driverId.orNull + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + submitResponse.unknownFields = unknownFields + } + submitResponse + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } + + /** + * Build a driver description from the fields specified in the submit request. + * + * This involves constructing a command that takes into account memory, java options, + * classpath and other settings to launch the driver. This does not currently consider + * fields used by python applications since python is not supported in standalone + * cluster mode yet. 
+ */ + private def buildDriverDescription(request: CreateSubmissionRequest): DriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + + // Construct driver description + val conf = new SparkConf(false) + .setAll(sparkProperties) + .set("spark.master", masterUrl) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + "org.apache.spark.deploy.worker.DriverWrapper", + Seq("{{WORKER_URL}}", "{{USER_JAR}}", mainClass) ++ appArgs, // args to the DriverWrapper + environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) + val actualSuperviseDriver = 
superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + new DriverDescription( + appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) + } +} + +/** + * A default servlet that handles error cases that are not captured by other servlets. + */ +private class ErrorServlet extends StandaloneRestServlet { + private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION + + /** Service a faulty request by returning an appropriate error message to the client. */ + protected override def service( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val path = request.getPathInfo + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList + var versionMismatch = false + var msg = + parts match { + case Nil => + // http://host:port/ + "Missing protocol version." + case `serverVersion` :: Nil => + // http://host:port/correct-version + "Missing the /submissions prefix." + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* + "Missing an action: please specify one of /create, /kill, or /status." + case unknownVersion :: tail => + // http://host:port/unknown-version/* + versionMismatch = true + s"Unknown protocol version '$unknownVersion'." + case _ => + // never reached + s"Malformed path $path." + } + msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." 
+ val error = handleError(msg) + // If there is a version mismatch, include the highest protocol version that + // this server supports in case the client wants to retry with our version + if (versionMismatch) { + error.highestProtocolVersion = serverVersion + response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + } else { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + } + sendResponse(error, response) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala new file mode 100644 index 0000000000000..d7a0bdbe10778 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * An exception thrown in the REST application submission protocol. + */ +private[spark] class SubmitRestProtocolException(message: String, cause: Throwable = null) + extends Exception(message, cause) + +/** + * An exception thrown if a field is missing from a [[SubmitRestProtocolMessage]]. 
+ */ +private[spark] class SubmitRestMissingFieldException(message: String) + extends SubmitRestProtocolException(message) + +/** + * An exception thrown if the REST client cannot reach the REST server. + */ +private[spark] class SubmitRestConnectionException(message: String, cause: Throwable) + extends SubmitRestProtocolException(message, cause) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala new file mode 100644 index 0000000000000..8f36635674a28 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.rest + +import com.fasterxml.jackson.annotation._ +import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper, SerializationFeature} +import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.util.Utils + +/** + * An abstract message exchanged in the REST application submission protocol. + * + * This message is intended to be serialized to and deserialized from JSON in the exchange. + * Each message can either be a request or a response and consists of three common fields: + * (1) the action, which fully specifies the type of the message + * (2) the Spark version of the client / server + * (3) an optional message + */ +@JsonInclude(Include.NON_NULL) +@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) +@JsonPropertyOrder(alphabetic = true) +private[spark] abstract class SubmitRestProtocolMessage { + @JsonIgnore + val messageType = Utils.getFormattedClassName(this) + + val action: String = messageType + var message: String = null + + // For JSON deserialization + private def setAction(a: String): Unit = { } + + /** + * Serialize the message to JSON. + * This also ensures that the message is valid and its fields are in the expected format. + */ + def toJson: String = { + validate() + SubmitRestProtocolMessage.mapper.writeValueAsString(this) + } + + /** + * Assert the validity of the message. + * If the validation fails, throw a [[SubmitRestProtocolException]]. 
+ */ + final def validate(): Unit = { + try { + doValidate() + } catch { + case e: Exception => + throw new SubmitRestProtocolException(s"Validation of message $messageType failed!", e) + } + } + + /** Assert the validity of the message */ + protected def doValidate(): Unit = { + if (action == null) { + throw new SubmitRestMissingFieldException(s"The action field is missing in $messageType") + } + } + + /** Assert that the specified field is set in this message. */ + protected def assertFieldIsSet[T](value: T, name: String): Unit = { + if (value == null) { + throw new SubmitRestMissingFieldException(s"'$name' is missing in message $messageType.") + } + } + + /** + * Assert a condition when validating this message. + * If the assertion fails, throw a [[SubmitRestProtocolException]]. + */ + protected def assert(condition: Boolean, failMessage: String): Unit = { + if (!condition) { throw new SubmitRestProtocolException(failMessage) } + } +} + +/** + * Helper methods to process serialized [[SubmitRestProtocolMessage]]s. + */ +private[spark] object SubmitRestProtocolMessage { + private val packagePrefix = this.getClass.getPackage.getName + private val mapper = new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .enable(SerializationFeature.INDENT_OUTPUT) + .registerModule(DefaultScalaModule) + + /** + * Parse the value of the action field from the given JSON. + * If the action field is not found, throw a [[SubmitRestMissingFieldException]]. + */ + def parseAction(json: String): String = { + val value: Option[String] = parse(json) match { + case JObject(fields) => + fields.collectFirst { case ("action", v) => v }.collect { case JString(s) => s } + case _ => None + } + value.getOrElse { + throw new SubmitRestMissingFieldException(s"Action field not found in JSON:\n$json") + } + } + + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. 
+ * + * This method first parses the action from the JSON and uses it to infer the message type. + * Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in + * this package. Otherwise, a [[ClassNotFoundException]] will be thrown. + */ + def fromJson(json: String): SubmitRestProtocolMessage = { + val className = parseAction(json) + val clazz = Class.forName(packagePrefix + "." + className) + .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) + fromJson(json, clazz) + } + + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method determines the type of the message from the class provided instead of + * inferring it from the action field. This is useful for deserializing JSON that + * represents custom user-defined messages. + */ + def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = { + mapper.readValue(json, clazz) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala new file mode 100644 index 0000000000000..9e1fd8c40cabd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import scala.util.Try + +import org.apache.spark.util.Utils + +/** + * An abstract request sent from the client in the REST application submission protocol. + */ +private[spark] abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { + var clientSparkVersion: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(clientSparkVersion, "clientSparkVersion") + } +} + +/** + * A request to launch a new application in the REST application submission protocol. + */ +private[spark] class CreateSubmissionRequest extends SubmitRestProtocolRequest { + var appResource: String = null + var mainClass: String = null + var appArgs: Array[String] = null + var sparkProperties: Map[String, String] = null + var environmentVariables: Map[String, String] = null + + protected override def doValidate(): Unit = { + super.doValidate() + assert(sparkProperties != null, "No Spark properties set!") + assertFieldIsSet(appResource, "appResource") + assertPropertyIsSet("spark.app.name") + assertPropertyIsBoolean("spark.driver.supervise") + assertPropertyIsNumeric("spark.driver.cores") + assertPropertyIsNumeric("spark.cores.max") + assertPropertyIsMemory("spark.driver.memory") + assertPropertyIsMemory("spark.executor.memory") + } + + private def assertPropertyIsSet(key: String): Unit = + assertFieldIsSet(sparkProperties.getOrElse(key, null), key) + + private def assertPropertyIsBoolean(key: String): Unit = + assertProperty[Boolean](key, "boolean", _.toBoolean) + + 
private def assertPropertyIsNumeric(key: String): Unit = + assertProperty[Int](key, "numeric", _.toInt) + + private def assertPropertyIsMemory(key: String): Unit = + assertProperty[Int](key, "memory", Utils.memoryStringToMb) + + /** Assert that a Spark property can be converted to a certain type. */ + private def assertProperty[T](key: String, valueType: String, convert: (String => T)): Unit = { + sparkProperties.get(key).foreach { value => + Try(convert(value)).getOrElse { + throw new SubmitRestProtocolException( + s"Property '$key' expected $valueType value: actual was '$value'.") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala new file mode 100644 index 0000000000000..16dfe041d4bea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.lang.Boolean + +/** + * An abstract response sent from the server in the REST application submission protocol. 
+ */ +private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { + var serverSparkVersion: String = null + var success: Boolean = null + var unknownFields: Array[String] = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(serverSparkVersion, "serverSparkVersion") + } +} + +/** + * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. + */ +private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(success, "success") + } +} + +/** + * A response to a kill request in the REST application submission protocol. + */ +private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + assertFieldIsSet(success, "success") + } +} + +/** + * A response to a status request in the REST application submission protocol. + */ +private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + var driverState: String = null + var workerId: String = null + var workerHostPort: String = null + + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + assertFieldIsSet(success, "success") + } +} + +/** + * An error response message used in the REST application submission protocol. 
+ */ +private[spark] class ErrorResponse extends SubmitRestProtocolResponse { + // The highest protocol version that the server knows about + // This is set when the client specifies an unknown version + var highestProtocolVersion: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(message, "message") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 28e9662db5da9..3e013c32096c5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -115,9 +115,19 @@ object CommandUtils extends Logging { val userClassPath = command.classPathEntries ++ Seq(classPath) val javaVersion = System.getProperty("java.version") - val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None + + val javaOpts = workerLocalOpts ++ command.javaOpts + + val permGenOpt = + if (!javaVersion.startsWith("1.8") && !javaOpts.exists(_.startsWith("-XX:MaxPermSize="))) { + // do not specify -XX:MaxPermSize if it was already specified by user + Some("-XX:MaxPermSize=128m") + } else { + None + } + Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++ - permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts + permGenOpt ++ javaOpts ++ memoryOpts } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 28cab36c7b9e2..e16bccb24d2c4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -20,19 +20,18 @@ package org.apache.spark.deploy.worker import java.io._ import scala.collection.JavaConversions._ 
-import scala.collection.Map import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.{Command, DriverDescription, SparkHadoopUtil} +import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.util.{Clock, SystemClock} /** * Manages the execution of one driver, including automatically restarting the driver on failure. @@ -59,9 +58,7 @@ private[spark] class DriverRunner( // Decoupled for testing private[deploy] def setClock(_clock: Clock) = clock = _clock private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper - private var clock = new Clock { - def currentTimeMillis(): Long = System.currentTimeMillis() - } + private var clock: Clock = new SystemClock() private var sleeper = new Sleeper { def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) } @@ -74,10 +71,15 @@ private[spark] class DriverRunner( val driverDir = createWorkingDirectory() val localJarFilename = downloadUserJar(driverDir) - // Make sure user application jar is on the classpath + def substituteVariables(argument: String): String = argument match { + case "{{WORKER_URL}}" => workerUrl + case "{{USER_JAR}}" => localJarFilename + case other => other + } + // TODO: If we add ability to submit multiple jars they should also be added here val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem, - sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename)) + sparkHome.getAbsolutePath, substituteVariables) launchDriver(builder, driverDir, driverDesc.supervise) } catch { @@ -111,12 +113,6 @@ 
private[spark] class DriverRunner( } } - /** Replace variables in a command argument passed to us */ - private def substituteVariables(argument: String): String = argument match { - case "{{WORKER_URL}}" => workerUrl - case other => other - } - /** * Creates the working directory for this driver. * Will throw an exception if there are errors preparing the directory. @@ -191,9 +187,9 @@ private[spark] class DriverRunner( initialize(process.get) } - val processStart = clock.currentTimeMillis() + val processStart = clock.getTimeMillis() val exitCode = process.get.waitFor() - if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { waitSeconds = 1 } @@ -209,10 +205,6 @@ private[spark] class DriverRunner( } } -private[deploy] trait Clock { - def currentTimeMillis(): Long -} - private[deploy] trait Sleeper { def sleep(seconds: Int) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 05e242e6df702..deef6ef9043c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -17,32 +17,51 @@ package org.apache.spark.deploy.worker +import java.io.File + import akka.actor._ import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils} /** * Utility object for launching driver programs such that they share fate with the Worker process. + * This is used in standalone cluster mode only. 
*/ object DriverWrapper { def main(args: Array[String]) { args.toList match { - case workerUrl :: mainClass :: extraArgs => + /* + * IMPORTANT: Spark 1.3 provides a stable application submission gateway that is both + * backward and forward compatible across future Spark versions. Because this gateway + * uses this class to launch the driver, the ordering and semantics of the arguments + * here must also remain consistent across versions. + */ + case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + val currentLoader = Thread.currentThread.getContextClassLoader + val userJarUrl = new File(userJar).toURI().toURL() + val loader = + if (sys.props.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + new ChildFirstURLClassLoader(Array(userJarUrl), currentLoader) + } else { + new MutableURLClassLoader(Array(userJarUrl), currentLoader) + } + Thread.currentThread.setContextClassLoader(loader) + // Delegate to supplied main class - val clazz = Class.forName(args(1)) + val clazz = Class.forName(mainClass, true, loader) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) actorSystem.shutdown() case _ => - System.err.println("Usage: DriverWrapper <workerUrl> <driverMainClass> [options]") + System.err.println("Usage: DriverWrapper <workerUrl> <userJar> <driverMainClass> [options]") System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index acbdf0d8bd7bc..6653aca0a0f06 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -26,7 +26,7 @@ import com.google.common.base.Charsets.UTF_8 import 
com.google.common.io.Files import org.apache.spark.{SparkConf, Logging} -import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.util.logging.FileAppender @@ -43,6 +43,7 @@ private[spark] class ExecutorRunner( val worker: ActorRef, val workerId: String, val host: String, + val webUiPort: Int, val sparkHome: File, val executorDir: File, val workerUrl: String, @@ -104,7 +105,11 @@ private[spark] class ExecutorRunner( workerThread.interrupt() workerThread = null state = ExecutorState.KILLED - Runtime.getRuntime.removeShutdownHook(shutdownHook) + try { + Runtime.getRuntime.removeShutdownHook(shutdownHook) + } catch { + case e: IllegalStateException => None + } } } @@ -130,10 +135,16 @@ private[spark] class ExecutorRunner( logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) builder.directory(executorDir) - builder.environment.put("SPARK_LOCAL_DIRS", appLocalDirs.mkString(",")) + builder.environment.put("SPARK_EXECUTOR_DIRS", appLocalDirs.mkString(File.pathSeparator)) // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0") + + // Add webUI log urls + val baseUrl = s"http://$host:$webUiPort/logPage/?appId=$appId&executorId=$execId&logType=" + builder.environment.put("SPARK_LOG_URL_STDERR", s"${baseUrl}stderr") + builder.environment.put("SPARK_LOG_URL_STDOUT", s"${baseUrl}stdout") + process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( command.mkString("\"", "\" \"", "\""), "=" * 40) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 13599830123d0..2473a90aa9309 100755 --- 
a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -31,8 +31,8 @@ import scala.util.Random import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI @@ -93,7 +93,12 @@ private[spark] class Worker( var masterAddress: Address = null var activeMasterUrl: String = "" var activeMasterWebUiUrl : String = "" - val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName) + val akkaUrl = AkkaUtils.address( + AkkaUtils.protocol(context.system), + actorSystemName, + host, + port, + actorName) @volatile var registered = false @volatile var connected = false val workerId = generateWorkerId() @@ -174,8 +179,9 @@ private[spark] class Worker( // activeMasterUrl it's a valid Spark url since we receive it from master. activeMasterUrl = url activeMasterWebUiUrl = uiUrl - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) - masterAddress = Master.toAkkaAddress(activeMasterUrl) + master = context.actorSelection( + Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system))) + masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system)) connected = true // Cancel any outstanding re-registration attempts because we found a new master registrationRetryTimer.foreach(_.cancel()) @@ -339,18 +345,29 @@ private[spark] class Worker( } // Create local dirs for the executor. 
These are passed to the executor via the - // SPARK_LOCAL_DIRS environment variable, and deleted by the Worker when the + // SPARK_EXECUTOR_DIRS environment variable, and deleted by the Worker when the // application finishes. val appLocalDirs = appDirectories.get(appId).getOrElse { Utils.getOrCreateLocalRootDirs(conf).map { dir => - Utils.createDirectory(dir).getAbsolutePath() + Utils.createDirectory(dir, namePrefix = "executor").getAbsolutePath() }.toSeq } appDirectories(appId) = appLocalDirs - - val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, - self, workerId, host, sparkHome, executorDir, akkaUrl, conf, appLocalDirs, - ExecutorState.LOADING) + val manager = new ExecutorRunner( + appId, + execId, + appDesc.copy(command = Worker.maybeUpdateSSLSettings(appDesc.command, conf)), + cores_, + memory_, + self, + workerId, + host, + webUiPort, + sparkHome, + executorDir, + akkaUrl, + conf, + appLocalDirs, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -406,7 +423,14 @@ private[spark] class Worker( case LaunchDriver(driverId, driverDesc) => { logInfo(s"Asked to launch driver $driverId") - val driver = new DriverRunner(conf, driverId, workDir, sparkHome, driverDesc, self, akkaUrl) + val driver = new DriverRunner( + conf, + driverId, + workDir, + sparkHome, + driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), + self, + akkaUrl) drivers(driverId) = driver driver.start() @@ -523,10 +547,32 @@ private[spark] object Worker extends Logging { val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, securityManager = securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl) + val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, masterAkkaUrls, 
systemName, actorName, workDir, conf, securityMgr), name = actorName) (actorSystem, boundPort) } + private[spark] def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { + val pattern = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r + val result = cmd.javaOpts.collectFirst { + case pattern(_result) => _result.toBoolean + } + result.getOrElse(false) + } + + private[spark] def maybeUpdateSSLSettings(cmd: Command, conf: SparkConf): Command = { + val prefix = "spark.ssl." + val useNLC = "spark.ssl.useNodeLocalConf" + if (isUseLocalNodeSSLConfig(cmd)) { + val newJavaOpts = cmd.javaOpts + .filter(opt => !opt.startsWith(s"-D$prefix")) ++ + conf.getAll.collect { case (key, value) if key.startsWith(prefix) => s"-D$key=$value" } :+ + s"-D$useNLC=true" + cmd.copy(javaOpts = newJavaOpts) + } else { + cmd + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 327b905032800..720f13bfa829b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -134,7 +134,7 @@ private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { def driverRow(driver: DriverRunner): Seq[Node] = { {driver.driverId} - {driver.driverDesc.command.arguments(1)} + {driver.driverDesc.command.arguments(2)} {driver.finalState.getOrElse(DriverState.RUNNING)} {driver.driverDesc.cores.toString} diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9a4adfbbb3d71..dd19e4947db1e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -17,8 +17,10 @@ package org.apache.spark.executor +import java.net.URL import java.nio.ByteBuffer 
+import scala.collection.mutable import scala.concurrent.Await import akka.actor.{Actor, ActorSelection, Props} @@ -38,6 +40,7 @@ private[spark] class CoarseGrainedExecutorBackend( executorId: String, hostPort: String, cores: Int, + userClassPath: Seq[URL], env: SparkEnv) extends Actor with ActorLogReceive with ExecutorBackend with Logging { @@ -49,15 +52,21 @@ private[spark] class CoarseGrainedExecutorBackend( override def preStart() { logInfo("Connecting to driver: " + driverUrl) driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores) + driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls) context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } + def extractLogUrls: Map[String, String] = { + val prefix = "SPARK_LOG_URL_" + sys.env.filterKeys(_.startsWith(prefix)) + .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) + } + override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) - executor = new Executor(executorId, hostname, env, isLocal = false) + executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -84,8 +93,12 @@ private[spark] class CoarseGrainedExecutorBackend( } case x: DisassociatedEvent => - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) + if (x.remoteAddress == driver.anchorPath.address) { + logError(s"Driver $x disassociated! 
Shutting down.") + System.exit(1) + } else { + logWarning(s"Received irrelevant DisassociatedEvent $x") + } case StopExecutor => logInfo("Driver commanded a shutdown") @@ -107,7 +120,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { hostname: String, cores: Int, appId: String, - workerUrl: Option[String]) { + workerUrl: Option[String], + userClassPath: Seq[URL]) { SignalLogger.register(log) @@ -119,7 +133,11 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) val (fetcher, _) = AkkaUtils.createActorSystem( - "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) + "driverPropsFetcher", + hostname, + port, + executorConf, + new SecurityManager(executorConf)) val driver = fetcher.actorSelection(driverUrl) val timeout = AkkaUtils.askTimeout(executorConf) val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) @@ -128,7 +146,15 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. 
- val driverConf = new SparkConf().setAll(props) + val driverConf = new SparkConf() + for ((key, value) <- props) { + // this is required for SSL in standalone mode + if (SparkConf.isExecutorStartupConf(key)) { + driverConf.setIfMissing(key, value) + } else { + driverConf.set(key, value) + } + } val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) @@ -140,7 +166,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val sparkHostPort = hostname + ":" + boundPort env.actorSystem.actorOf( Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, env), + driverUrl, executorId, sparkHostPort, cores, userClassPath, env), name = "Executor") workerUrl.foreach { url => env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") @@ -150,20 +176,69 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } def main(args: Array[String]) { - args.length match { - case x if x < 5 => - System.err.println( + var driverUrl: String = null + var executorId: String = null + var hostname: String = null + var cores: Int = 0 + var appId: String = null + var workerUrl: Option[String] = None + val userClassPath = new mutable.ListBuffer[URL]() + + var argv = args.toList + while (!argv.isEmpty) { + argv match { + case ("--driver-url") :: value :: tail => + driverUrl = value + argv = tail + case ("--executor-id") :: value :: tail => + executorId = value + argv = tail + case ("--hostname") :: value :: tail => + hostname = value + argv = tail + case ("--cores") :: value :: tail => + cores = value.toInt + argv = tail + case ("--app-id") :: value :: tail => + appId = value + argv = tail + case ("--worker-url") :: value :: tail => // Worker url is used in spark standalone mode to enforce fate-sharing with worker - "Usage: CoarseGrainedExecutorBackend " + - " [] ") - System.exit(1) + workerUrl = Some(value) + argv = tail + case ("--user-class-path") :: value 
:: tail => + userClassPath += new URL(value) + argv = tail + case Nil => + case tail => + System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + printUsageAndExit() + } + } - // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode) - // and CoarseMesosSchedulerBackend (for mesos mode). - case 5 => - run(args(0), args(1), args(2), args(3).toInt, args(4), None) - case x if x > 5 => - run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5))) + if (driverUrl == null || executorId == null || hostname == null || cores <= 0 || + appId == null) { + printUsageAndExit() } + + run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) } + + private def printUsageAndExit() = { + System.err.println( + """ + |"Usage: CoarseGrainedExecutorBackend [options] + | + | Options are: + | --driver-url + | --executor-id + | --hostname + | --cores + | --app-id + | --worker-url + | --user-class-path + |""".stripMargin) + System.exit(1) + } + } diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala new file mode 100644 index 0000000000000..f7604a321f007 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.{TaskCommitDenied, TaskEndReason} + +/** + * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver. + */ +class CommitDeniedException( + msg: String, + jobID: Int, + splitID: Int, + attemptID: Int) + extends Exception(msg) { + + def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID) + +} + diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 6660b98eb8ce9..b684fb704956b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.io.File import java.lang.management.ManagementFactory +import java.net.URL import java.nio.ByteBuffer import java.util.concurrent._ @@ -33,7 +34,8 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, + SparkUncaughtExceptionHandler, AkkaUtils, Utils} /** * Spark executor used with Mesos, YARN, and the standalone scheduler. 
@@ -41,11 +43,15 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} */ private[spark] class Executor( executorId: String, - slaveHostname: String, + executorHostname: String, env: SparkEnv, + userClassPath: Seq[URL] = Nil, isLocal: Boolean = false) extends Logging { + + logInfo(s"Starting executor ID $executorId on host $executorHostname") + // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -58,12 +64,12 @@ private[spark] class Executor( @volatile private var isStopped = false // No ip or host:port - just hostname - Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. - assert (0 == Utils.parseHostPort(slaveHostname)._2) + assert (0 == Utils.parseHostPort(executorHostname)._2) // Make sure the local hostname we report matches the cluster scheduler's name for this host - Utils.setCustomHostname(slaveHostname) + Utils.setCustomHostname(executorHostname) if (!isLocal) { // Setup an uncaught exception handler for non-local mode. 
@@ -72,8 +78,10 @@ private[spark] class Executor( Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) } + // Start worker thread pool + val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") + val executorSource = new ExecutorSource(this, executorId) - conf.set("spark.executor.id", executorId) if (!isLocal) { env.metricsSystem.registerSource(executorSource) @@ -99,9 +107,6 @@ private[spark] class Executor( // Limit of bytes for total size of results (default is 1GB) private val maxResultSize = Utils.getMaxResultSize(conf) - // Start worker thread pool - val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") - // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] @@ -203,10 +208,10 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.executorDeserializeTime = taskStart - deserializeStartTime - m.executorRunTime = taskFinish - taskStart - m.jvmGCTime = gcTime - startGCTime - m.resultSerializationTime = afterSerialization - beforeSerialization + m.setExecutorDeserializeTime(taskStart - deserializeStartTime) + m.setExecutorRunTime(taskFinish - taskStart) + m.setJvmGCTime(gcTime - startGCTime) + m.setResultSerializationTime(afterSerialization - beforeSerialization) } val accumUpdates = Accumulators.values @@ -248,6 +253,11 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) } + case cDE: CommitDeniedException => { + val reason = cDE.toTaskEndReason + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + } + case t: Throwable => { // Attempt to exit cleanly by informing the driver of our failure. 
// If anything goes wrong (or this was a fatal exception), we will delegate to @@ -257,8 +267,8 @@ private[spark] class Executor( val serviceTime = System.currentTimeMillis() - taskStart val metrics = attemptedTask.flatMap(t => t.metrics) for (m <- metrics) { - m.executorRunTime = serviceTime - m.jvmGCTime = gcTime - startGCTime + m.setExecutorRunTime(serviceTime) + m.setJvmGCTime(gcTime - startGCTime) } val reason = new ExceptionFailure(t, metrics) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) @@ -286,17 +296,23 @@ private[spark] class Executor( * created by the interpreter to the search path */ private def createClassLoader(): MutableURLClassLoader = { + // Bootstrap the list of jars with the user class path. + val now = System.currentTimeMillis() + userClassPath.foreach { url => + currentJars(url.getPath().split("/").last) = now + } + val currentLoader = Utils.getContextOrSparkClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. 
- val urls = currentJars.keySet.map { uri => + val urls = userClassPath.toArray ++ currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL - }.toArray - val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false) - userClassPathFirst match { - case true => new ChildExecutorURLClassLoader(urls, currentLoader) - case false => new ExecutorURLClassLoader(urls, currentLoader) + } + if (conf.getBoolean("spark.executor.userClassPathFirst", false)) { + new ChildFirstURLClassLoader(urls, currentLoader) + } else { + new MutableURLClassLoader(urls, currentLoader) } } @@ -309,7 +325,7 @@ private[spark] class Executor( if (classUri != null) { logInfo("Using REPL class URI: " + classUri) val userClassPathFirst: java.lang.Boolean = - conf.getBoolean("spark.files.userClassPathFirst", false) + conf.getBoolean("spark.executor.userClassPathFirst", false) try { val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] @@ -342,18 +358,23 @@ private[spark] class Executor( env.securityManager, hadoopConf, timestamp, useCache = !isLocal) currentFiles(name) = timestamp } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - // Fetch file with useCache mode, close cache for local mode. 
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, - env.securityManager, hadoopConf, timestamp, useCache = !isLocal) - currentJars(name) = timestamp - // Add it to our class loader + for ((name, timestamp) <- newJars) { val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) + val currentTimeStamp = currentJars.get(name) + .orElse(currentJars.get(localName)) + .getOrElse(-1L) + if (currentTimeStamp < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + // Fetch file with useCache mode, close cache for local mode. + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, + env.securityManager, hadoopConf, timestamp, useCache = !isLocal) + currentJars(name) = timestamp + // Add it to our class loader + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } } } } @@ -376,11 +397,12 @@ private[spark] class Executor( val curGCTime = gcTime for (taskRunner <- runningTasks.values()) { - if (!taskRunner.attemptedTask.isEmpty) { + if (taskRunner.attemptedTask.nonEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => - metrics.updateShuffleReadMetrics + metrics.updateShuffleReadMetrics() metrics.updateInputMetrics() - metrics.jvmGCTime = curGCTime - taskRunner.startGCTime + metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + if (isLocal) { // JobProgressListener will hold an reference of it during // onExecutorMetricsUpdate(), then JobProgressListener can not see diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala deleted file mode 100644 index 
218ed7b5d2d39..0000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import java.net.{URLClassLoader, URL} - -import org.apache.spark.util.ParentClassLoader - -/** - * The addURL method in URLClassLoader is protected. We subclass it to make this accessible. - * We also make changes so user classes can come before the default classes. 
- */ - -private[spark] trait MutableURLClassLoader extends ClassLoader { - def addURL(url: URL) - def getURLs: Array[URL] -} - -private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends MutableURLClassLoader { - - private object userClassLoader extends URLClassLoader(urls, null){ - override def addURL(url: URL) { - super.addURL(url) - } - override def findClass(name: String): Class[_] = { - super.findClass(name) - } - } - - private val parentClassLoader = new ParentClassLoader(parent) - - override def findClass(name: String): Class[_] = { - try { - userClassLoader.findClass(name) - } catch { - case e: ClassNotFoundException => { - parentClassLoader.loadClass(name) - } - } - } - - def addURL(url: URL) { - userClassLoader.addURL(url) - } - - def getURLs() = { - userClassLoader.getURLs() - } -} - -private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends URLClassLoader(urls, parent) with MutableURLClassLoader { - - override def addURL(url: URL) { - super.addURL(url) - } -} - diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 7eb10f95e023b..07b152651dedf 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -19,7 +19,6 @@ package org.apache.spark.executor import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.executor.DataReadMethod import org.apache.spark.executor.DataReadMethod.DataReadMethod import scala.collection.mutable.ArrayBuffer @@ -44,42 +43,62 @@ class TaskMetrics extends Serializable { /** * Host's name the task runs on */ - var hostname: String = _ - + private var _hostname: String = _ + def hostname = _hostname + private[spark] def setHostname(value: String) = _hostname = value + /** * Time taken on the executor to deserialize this task */ - var executorDeserializeTime: Long 
= _ - + private var _executorDeserializeTime: Long = _ + def executorDeserializeTime = _executorDeserializeTime + private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value + + /** * Time the executor spends actually running the task (including fetching shuffle data) */ - var executorRunTime: Long = _ - + private var _executorRunTime: Long = _ + def executorRunTime = _executorRunTime + private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value + /** * The number of bytes this task transmitted back to the driver as the TaskResult */ - var resultSize: Long = _ + private var _resultSize: Long = _ + def resultSize = _resultSize + private[spark] def setResultSize(value: Long) = _resultSize = value + /** * Amount of time the JVM spent in garbage collection while executing this task */ - var jvmGCTime: Long = _ + private var _jvmGCTime: Long = _ + def jvmGCTime = _jvmGCTime + private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value /** * Amount of time spent serializing the task result */ - var resultSerializationTime: Long = _ + private var _resultSerializationTime: Long = _ + def resultSerializationTime = _resultSerializationTime + private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value /** * The number of in-memory bytes spilled by this task */ - var memoryBytesSpilled: Long = _ + private var _memoryBytesSpilled: Long = _ + def memoryBytesSpilled = _memoryBytesSpilled + private[spark] def incMemoryBytesSpilled(value: Long) = _memoryBytesSpilled += value + private[spark] def decMemoryBytesSpilled(value: Long) = _memoryBytesSpilled -= value /** * The number of on-disk bytes spilled by this task */ - var diskBytesSpilled: Long = _ + private var _diskBytesSpilled: Long = _ + def diskBytesSpilled = _diskBytesSpilled + def incDiskBytesSpilled(value: Long) = _diskBytesSpilled += value + def decDiskBytesSpilled(value: Long) = _diskBytesSpilled -= value /** * If this task reads 
from a HadoopRDD or from persisted data, metrics on how much data was read @@ -158,8 +177,8 @@ class TaskMetrics extends Serializable { * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, * we can store all the different inputMetrics (one per readMethod). */ - private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): - InputMetrics =synchronized { + private[spark] def getInputMetricsForReadMethod( + readMethod: DataReadMethod): InputMetrics = synchronized { _inputMetrics match { case None => val metrics = new InputMetrics(readMethod) @@ -175,18 +194,22 @@ class TaskMetrics extends Serializable { /** * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. */ - private[spark] def updateShuffleReadMetrics() = synchronized { - val merged = new ShuffleReadMetrics() - for (depMetrics <- depsShuffleReadMetrics) { - merged.fetchWaitTime += depMetrics.fetchWaitTime - merged.localBlocksFetched += depMetrics.localBlocksFetched - merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched - merged.remoteBytesRead += depMetrics.remoteBytesRead + private[spark] def updateShuffleReadMetrics(): Unit = synchronized { + if (!depsShuffleReadMetrics.isEmpty) { + val merged = new ShuffleReadMetrics() + for (depMetrics <- depsShuffleReadMetrics) { + merged.incFetchWaitTime(depMetrics.fetchWaitTime) + merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) + merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) + merged.incRemoteBytesRead(depMetrics.remoteBytesRead) + merged.incLocalBytesRead(depMetrics.localBytesRead) + merged.incRecordsRead(depMetrics.recordsRead) + } + _shuffleReadMetrics = Some(merged) } - _shuffleReadMetrics = Some(merged) } - private[spark] def updateInputMetrics() = synchronized { + private[spark] def updateInputMetrics(): Unit = synchronized { inputMetrics.foreach(_.updateBytesRead()) } } @@ -223,27 +246,31 @@ object DataWriteMethod extends Enumeration with Serializable { 
@DeveloperApi case class InputMetrics(readMethod: DataReadMethod.Value) { - private val _bytesRead: AtomicLong = new AtomicLong() + /** + * This is volatile so that it is visible to the updater thread. + */ + @volatile @transient var bytesReadCallback: Option[() => Long] = None /** * Total bytes read. */ - def bytesRead: Long = _bytesRead.get() - @volatile @transient var bytesReadCallback: Option[() => Long] = None + private var _bytesRead: Long = _ + def bytesRead: Long = _bytesRead + def incBytesRead(bytes: Long) = _bytesRead += bytes /** - * Adds additional bytes read for this read method. + * Total records read. */ - def addBytesRead(bytes: Long) = { - _bytesRead.addAndGet(bytes) - } + private var _recordsRead: Long = _ + def recordsRead: Long = _recordsRead + def incRecordsRead(records: Long) = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. */ def updateBytesRead() { bytesReadCallback.foreach { c => - _bytesRead.set(c()) + _bytesRead = c() } } @@ -265,7 +292,16 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) { /** * Total bytes written */ - var bytesWritten: Long = 0L + private var _bytesWritten: Long = _ + def bytesWritten = _bytesWritten + private[spark] def setBytesWritten(value : Long) = _bytesWritten = value + + /** + * Total records written + */ + private var _recordsWritten: Long = 0L + def recordsWritten = _recordsWritten + private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value } /** @@ -274,32 +310,64 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) { */ @DeveloperApi class ShuffleReadMetrics extends Serializable { - /** - * Number of blocks fetched in this shuffle by this task (remote or local) - */ - def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched - /** * Number of remote blocks fetched in this shuffle by this task */ - var remoteBlocksFetched: Int = _ - + private var _remoteBlocksFetched: Int = _ + def remoteBlocksFetched = 
_remoteBlocksFetched + private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value + private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value + /** * Number of local blocks fetched in this shuffle by this task */ - var localBlocksFetched: Int = _ + private var _localBlocksFetched: Int = _ + def localBlocksFetched = _localBlocksFetched + private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value + private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value /** * Time the task spent waiting for remote shuffle blocks. This only includes the time * blocking on shuffle input data. For instance if block B is being fetched while the task is * still not finished processing block A, it is not considered to be blocking on block B. */ - var fetchWaitTime: Long = _ - + private var _fetchWaitTime: Long = _ + def fetchWaitTime = _fetchWaitTime + private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value + private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value + /** * Total number of remote bytes read from the shuffle by this task */ - var remoteBytesRead: Long = _ + private var _remoteBytesRead: Long = _ + def remoteBytesRead = _remoteBytesRead + private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value + private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value + + /** + * Shuffle data that was read from the local disk (as opposed to from a remote executor). + */ + private var _localBytesRead: Long = _ + def localBytesRead = _localBytesRead + private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value + + /** + * Total bytes fetched in the shuffle by this task (both remote and local). 
+ */ + def totalBytesRead = _remoteBytesRead + _localBytesRead + + /** + * Number of blocks fetched in this shuffle by this task (remote or local) + */ + def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched + + /** + * Total number of records read from the shuffle by this task + */ + private var _recordsRead: Long = _ + def recordsRead = _recordsRead + private[spark] def incRecordsRead(value: Long) = _recordsRead += value + private[spark] def decRecordsRead(value: Long) = _recordsRead -= value } /** @@ -311,10 +379,25 @@ class ShuffleWriteMetrics extends Serializable { /** * Number of bytes written for the shuffle by this task */ - @volatile var shuffleBytesWritten: Long = _ - + @volatile private var _shuffleBytesWritten: Long = _ + def shuffleBytesWritten = _shuffleBytesWritten + private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value + private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value + /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ - @volatile var shuffleWriteTime: Long = _ + @volatile private var _shuffleWriteTime: Long = _ + def shuffleWriteTime= _shuffleWriteTime + private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value + private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value + + /** + * Total number of records written to the shuffle by this task + */ + @volatile private var _shuffleRecordsWritten: Long = _ + def shuffleRecordsWritten = _shuffleRecordsWritten + private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value + private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value + private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 
1b7a5d1f1980a..8edf493780687 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -28,12 +28,12 @@ import org.apache.spark.util.Utils private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { - val DEFAULT_PREFIX = "*" - val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r - val METRICS_CONF = "metrics.properties" + private val DEFAULT_PREFIX = "*" + private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r + private val DEFAULT_METRICS_CONF_FILENAME = "metrics.properties" - val properties = new Properties() - var propertyCategories: mutable.HashMap[String, Properties] = null + private[metrics] val properties = new Properties() + private[metrics] var propertyCategories: mutable.HashMap[String, Properties] = null private def setDefaultProperties(prop: Properties) { prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet") @@ -47,20 +47,22 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi setDefaultProperties(properties) // If spark.metrics.conf is not set, try to get file in class path - var is: InputStream = null - try { - is = configFile match { - case Some(f) => new FileInputStream(f) - case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF) + val isOpt: Option[InputStream] = configFile.map(new FileInputStream(_)).orElse { + try { + Option(Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME)) + } catch { + case e: Exception => + logError("Error loading default configuration file", e) + None } + } - if (is != null) { + isOpt.foreach { is => + try { properties.load(is) + } finally { + is.close() } - } catch { - case e: Exception => logError("Error loading configure file", e) - } finally { - if (is != null) is.close() } propertyCategories = subProperties(properties, INSTANCE_REGEX) diff --git 
a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 45633e3de01dd..345db36630fd5 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -130,8 +130,8 @@ private[spark] class MetricsSystem private ( if (appId.isDefined && executorId.isDefined) { MetricRegistry.name(appId.get, executorId.get, source.sourceName) } else { - // Only Driver and Executor are set spark.app.id and spark.executor.id. - // For instance, Master and Worker are not related to a specific application. + // Only Driver and Executor set spark.app.id and spark.executor.id. + // Other instance types, e.g. Master and Worker, are not related to a specific application. val warningMsg = s"Using default name $defaultName for source because %s is not set." if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) } if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) } @@ -191,7 +191,10 @@ private[spark] class MetricsSystem private ( sinks += sink.asInstanceOf[Sink] } } catch { - case e: Exception => logError("Sink class " + classPath + " cannot be instantialized", e) + case e: Exception => { + logError("Sink class " + classPath + " cannot be instantialized") + throw e + } } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index d7b5f5c40efae..2d25ebd66159f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -22,7 +22,7 @@ import java.util.Properties import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.graphite.{Graphite, GraphiteReporter} +import com.codahale.metrics.graphite.{GraphiteUDP, Graphite, GraphiteReporter} import 
org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem @@ -38,6 +38,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric val GRAPHITE_KEY_PERIOD = "period" val GRAPHITE_KEY_UNIT = "unit" val GRAPHITE_KEY_PREFIX = "prefix" + val GRAPHITE_KEY_PROTOCOL = "protocol" def propertyToOption(prop: String): Option[String] = Option(property.getProperty(prop)) @@ -66,7 +67,11 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - val graphite: Graphite = new Graphite(new InetSocketAddress(host, port)) + val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match { + case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) + case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) + case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") + } val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry) .convertDurationsTo(TimeUnit.MILLISECONDS) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala new file mode 100644 index 0000000000000..e8b3074e8f1a6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import java.util.Properties +import java.util.concurrent.TimeUnit + +import com.codahale.metrics.{Slf4jReporter, MetricRegistry} + +import org.apache.spark.SecurityManager +import org.apache.spark.metrics.MetricsSystem + +private[spark] class Slf4jSink( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink { + val SLF4J_DEFAULT_PERIOD = 10 + val SLF4J_DEFAULT_UNIT = "SECONDS" + + val SLF4J_KEY_PERIOD = "period" + val SLF4J_KEY_UNIT = "unit" + + val pollPeriod = Option(property.getProperty(SLF4J_KEY_PERIOD)) match { + case Some(s) => s.toInt + case None => SLF4J_DEFAULT_PERIOD + } + + val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT) + } + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val reporter: Slf4jReporter = Slf4jReporter.forRegistry(registry) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .build() + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } + + override def report() { + reporter.report() + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 03c4137ca0a81..ee22c6656e69e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ 
b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -184,14 +184,16 @@ private[nio] class ConnectionManager( // to be able to track asynchronous messages private val idCount: AtomicInteger = new AtomicInteger(1) + private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } selectorThread.setDaemon(true) + // start this thread last, since it invokes run(), which accesses members above selectorThread.start() - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - private def triggerWrite(key: SelectionKey) { val conn = connectionsByKey.getOrElse(key, null) if (conn == null) return @@ -232,7 +234,6 @@ private[nio] class ConnectionManager( } ) } - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private def triggerRead(key: SelectionKey) { val conn = connectionsByKey.getOrElse(key, null) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 5ad73c3d27f47..b6249b492150a 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -27,8 +27,7 @@ package org.apache * contains operations available only on RDDs of Doubles; and * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that can * be saved as SequenceFiles. These operations are automatically available on any RDD of the right - * type (e.g. RDD[(Int, Int)] through implicit conversions except `saveAsSequenceFile`. You need to - * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work. + * type (e.g. RDD[(Int, Int)] through implicit conversions. 
* * Java programmers should reference the [[org.apache.spark.api.java]] package * for Spark programming APIs in Java. diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 70edf191d928a..07398a6fa62f6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -159,8 +159,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: for ((it, depNum) <- rddIterators) { map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } - context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled - context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled + context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled) new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index e66f83bb34e30..03afc289736bb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -213,7 +213,14 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } else { basicBucketFunction _ } - self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + if (self.partitions.length == 0) { + new Array[Long](buckets.length - 1) + } else { + // reduce() requires a non-empty RDD. This works because the mapPartitions will make + // non-empty partitions out of empty ones. 
But it doesn't handle the no-partitions case, + // which is below + self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 3b99d3a6cafd1..486e86ce1bb19 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,16 +35,18 @@ import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID +import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.util.{NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} +import org.apache.spark.storage.StorageLevel /** * A Spark split class that wraps around a Hadoop InputSplit. @@ -218,13 +220,13 @@ class HadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.inputSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, jobConf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null @@ -245,6 +247,9 @@ class HadoopRDD[K, V]( case eof: EOFException => finished = true } + if (!finished) { + inputMetrics.incRecordsRead(1) + } (key, value) } @@ -253,11 +258,12 @@ class HadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { + } else if (split.inputSplit.value.isInstanceOf[FileSplit] || + split.inputSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.addBytesRead(split.inputSplit.value.getLength) + inputMetrics.incBytesRead(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) @@ -305,6 +311,15 @@ class HadoopRDD[K, V]( // Do nothing. Hadoop RDD should not be checkpointed. } + override def persist(storageLevel: StorageLevel): this.type = { + if (storageLevel.deserialized) { + logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." 
+ + " Use a map transformation to make copies of the records.") + } + super.persist(storageLevel) + } + def getConf: Configuration = getJobConf() } diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 642a12c1edf6c..e2267861e79df 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -62,11 +62,11 @@ class JdbcRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { // bounds are inclusive, hence the + 1 here and - 1 on end - val length = 1 + upperBound - lowerBound + val length = BigInt(1) + upperBound - lowerBound (0 until numPartitions).map(i => { - val start = lowerBound + ((i * length) / numPartitions).toLong - val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 - new JdbcPartition(i, start, end) + val start = lowerBound + ((i * length) / numPartitions) + val end = lowerBound + (((i + 1) * length) / numPartitions) - 1 + new JdbcPartition(i, start.toLong, end.toLong) }).toArray } @@ -99,21 +99,21 @@ class JdbcRDD[T: ClassTag]( override def close() { try { - if (null != rs && ! rs.isClosed()) { + if (null != rs) { rs.close() } } catch { case e: Exception => logWarning("Exception closing resultset", e) } try { - if (null != stmt && ! stmt.isClosed()) { + if (null != stmt) { stmt.close() } } catch { case e: Exception => logWarning("Exception closing statement", e) } try { - if (null != conn && ! 
conn.isClosed()) { + if (null != conn) { conn.close() } logInfo("closed connection") diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 890ec677c2690..7fb94840df99c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -25,20 +25,17 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.input.WholeTextFileInputFormat -import org.apache.spark.InterruptibleIterator -import org.apache.spark.Logging -import org.apache.spark.Partition -import org.apache.spark.SerializableWritable -import org.apache.spark.{SparkContext, TaskContext} -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark._ +import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.storage.StorageLevel private[spark] class NewHadoopPartition( rddId: Int, @@ -114,13 +111,13 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.serializableHadoopSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, conf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) @@ -154,6 +151,9 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false + if (!finished) { + inputMetrics.incRecordsRead(1) + } (reader.getCurrentKey, reader.getCurrentValue) } @@ -162,11 +162,12 @@ class NewHadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.addBytesRead(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) @@ -209,6 +210,16 @@ class NewHadoopRDD[K, V]( locs.getOrElse(split.getLocations.filter(_ != "localhost")) } + override def persist(storageLevel: StorageLevel): this.type = { + if (storageLevel.deserialized) { + logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." 
+ + " Use a map transformation to make copies of the records.") + } + super.persist(storageLevel) + } + + def getConf: Configuration = confBroadcast.value.value } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 144f679a59460..6fdfdb734d1b8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -75,4 +75,27 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering) } + /** + * Returns an RDD containing only the elements in the the inclusive range `lower` to `upper`. + * If the RDD has been partitioned using a `RangePartitioner`, then this operation can be + * performed efficiently by only scanning the partitions that might contain matching elements. + * Otherwise, a standard `filter` is applied to all partitions. + */ + def filterByRange(lower: K, upper: K): RDD[P] = { + + def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper) + + val rddToFilter: RDD[P] = self.partitioner match { + case Some(rp: RangePartitioner[K, V]) => { + val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match { + case (l, u) => Math.min(l, u) to Math.max(l, u) + } + PartitionPruningRDD.create(self, partitionIndicies.contains) + } + case _ => + self + } + rddToFilter.filter { case (k, v) => inRange(k) } + } + } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index e43e5066655b9..955b42c3baaa1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import 
org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, -RecordWriter => NewRecordWriter} + RecordWriter => NewRecordWriter} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner @@ -990,11 +990,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + var recordsWritten = 0L try { - var recordsWritten = 0L while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) @@ -1007,7 +1007,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.close(hadoopContext) } committer.commitTask(hadoopContext) - bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } + bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) 1 } : Int @@ -1061,12 +1062,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. 
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() + var recordsWritten = 0L try { - var recordsWritten = 0L while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) @@ -1079,18 +1080,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.close() } writer.commit() - bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } + bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) } self.context.runJob(self, writeToFile) writer.commitJob() } - private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) - : (OutputMetrics, Option[() => Long]) = { - val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) - .map(new Path(_)) - .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) if (bytesWrittenCallback.isDefined) { context.taskMetrics.outputMetrics = Some(outputMetrics) @@ -1100,9 +1099,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { - if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0 - && bytesWrittenCallback.isDefined) { - bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } + if (recordsWritten % 
PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { + bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 87b22de6ae697..f12d0cffaba34 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -111,7 +111,8 @@ private object ParallelCollectionRDD { /** * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes - * it efficient to run Spark over RDDs representing large sets of numbers. + * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection + * is an inclusive Range, we use inclusive range for the last slice. 
*/ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { if (numSlices < 1) { @@ -127,19 +128,15 @@ private object ParallelCollectionRDD { }) } seq match { - case r: Range.Inclusive => { - val sign = if (r.step < 0) { - -1 - } else { - 1 - } - slice(new Range( - r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) - } case r: Range => { - positions(r.length, numSlices).map({ - case (start, end) => + positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) => + // If the range is inclusive, use inclusive range for the last slice + if (r.isInclusive && index == numSlices - 1) { + new Range.Inclusive(r.start + start * r.step, r.end, r.step) + } + else { new Range(r.start + start * r.step, r.start + end * r.step, r.step) + } }).toSeq.asInstanceOf[Seq[Seq[T]]] } case nr: NumericRange[_] => { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 5118e2b911120..cf0433010aa03 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -25,11 +25,8 @@ import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.{Writable, BytesWritable, NullWritable, Text} import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ @@ -57,8 +54,7 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that * can be saved as SequenceFiles. * All operations are automatically available on any RDD of the right type (e.g. 
RDD[(Int, Int)] - * through implicit conversions except `saveAsSequenceFile`. You need to - * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work. + * through implicit conversions. * * Internally, each RDD is characterized by five main properties: * @@ -76,10 +72,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * on RDD internals. */ abstract class RDD[T: ClassTag]( - @transient private var sc: SparkContext, + @transient private var _sc: SparkContext, @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { + if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have defined nested RDDs without running jobs with them. + logWarning("Spark does not support nested RDDs (see SPARK-5063)") + } + + private def sc: SparkContext = { + if (_sc == null) { + throw new SparkException( + "RDD transformations and actions can only be invoked by the driver, not inside of other " + + "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " + + "the values transformation and count action cannot be performed inside of the rdd1.map " + + "transformation. For more information, see SPARK-5063.") + } + _sc + } + /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) @@ -449,7 +462,13 @@ abstract class RDD[T: ClassTag]( * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). 
*/ - def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) + def union(other: RDD[T]): RDD[T] = { + if (partitioner.isDefined && other.partitioner == partitioner) { + new PartitionerAwareUnionRDD(sc, Array(this, other)) + } else { + new UnionRDD(sc, Array(this, other)) + } + } /** * Return the union of this RDD and another one. Any identical elements will appear multiple @@ -587,8 +606,8 @@ abstract class RDD[T: ClassTag]( * print line function (like out.println()) as the 2nd parameter. * An example of pipe the RDD data of groupBy() in a streaming way, * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2){f(e)} + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} * @param separateWorkingDir Use separate working directories for each task. * @return the result RDD */ @@ -824,7 +843,7 @@ abstract class RDD[T: ClassTag]( * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtract(other: RDD[T]): RDD[T] = subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) @@ -883,6 +902,38 @@ abstract class RDD[T: ClassTag]( jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } + /** + * Reduces the elements of this RDD in a multi-level tree pattern. 
+ * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#reduce]] + */ + def treeReduce(f: (T, T) => T, depth: Int = 2): T = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + val cleanF = context.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it))) + val op: (Option[T], Option[T]) => Option[T] = (c, x) => { + if (c.isDefined && x.isDefined) { + Some(cleanF(c.get, x.get)) + } else if (c.isDefined) { + c + } else if (x.isDefined) { + x + } else { + None + } + } + partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth) + .getOrElse(throw new UnsupportedOperationException("empty collection")) + } + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to @@ -918,6 +969,37 @@ abstract class RDD[T: ClassTag]( jobResult } + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. 
+ * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#aggregate]] + */ + def treeAggregate[U: ClassTag](zeroValue: U)( + seqOp: (U, T) => U, + combOp: (U, U) => U, + depth: Int = 2): U = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + if (partitions.size == 0) { + return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) + } + val cleanSeqOp = context.clean(seqOp) + val cleanCombOp = context.clean(combOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var numPartitions = partiallyAggregated.partitions.size + val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) + // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. + while (numPartitions > scale + numPartitions / scale) { + numPartitions /= scale + val curNumPartitions = numPartitions + partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => + iter.map((i % curNumPartitions, _)) + }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + } + partiallyAggregated.reduce(cleanCombOp) + } + /** * Return the number of elements in the RDD. */ @@ -947,7 +1029,7 @@ abstract class RDD[T: ClassTag]( * * Note that this method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. - * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which + * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. 
 */ def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = { @@ -985,7 +1067,7 @@ abstract class RDD[T: ClassTag]( * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available * here. * - * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` + * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` * would trigger sparse representation of registers, which may reduce the memory consumption * and increase accuracy when the cardinality is small. * @@ -1064,6 +1146,9 @@ abstract class RDD[T: ClassTag]( * Take the first num elements of the RDD. It works by first scanning one partition, and use the * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. + * + * @note due to complications in the internal implementation, this method will raise + * an exception if called on an RDD of `Nothing` or `Null`. */ def take(num: Int): Array[T] = { if (num == 0) { @@ -1175,6 +1260,16 @@ abstract class RDD[T: ClassTag]( * */ def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min) + /** + * @note due to complications in the internal implementation, this method will raise an + * exception if called on an RDD of `Nothing` or `Null`. This may come up in practice + * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. + * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) + * @return true if and only if the RDD contains no elements at all. Note that an RDD + * may be empty even when it has at least 1 partition. + */ + def isEmpty(): Boolean = partitions.length == 0 || take(1).length == 0 + /** * Save this RDD as a text file, using string representations of elements. */ @@ -1297,7 +1392,7 @@ abstract class RDD[T: ClassTag]( /** * Private API for changing an RDD's ClassTag. - * Used for internal Java <-> Scala API compatibility. 
+ * Used for internal Java-Scala API compatibility. */ private[spark] def retag(cls: Class[T]): RDD[T] = { val classTag: ClassTag[T] = ClassTag.apply(cls) @@ -1306,7 +1401,7 @@ abstract class RDD[T: ClassTag]( /** * Private API for changing an RDD's ClassTag. - * Used for internal Java <-> Scala API compatibility. + * Used for internal Java-Scala API compatibility. */ private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = { this.mapPartitions(identity, preservesPartitioning = true)(classTag) @@ -1441,7 +1536,7 @@ abstract class RDD[T: ClassTag]( */ object RDD { - // The following implicit functions were in SparkContext before 1.2 and users had to + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward // compatibility and forward to the following functions directly. 
@@ -1455,9 +1550,15 @@ object RDD { new AsyncRDDActions(rdd) } - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { - new SequenceFileRDDFunctions(rdd) + implicit def rddToSequenceFileRDDFunctions[K, V](rdd: RDD[(K, V)]) + (implicit kt: ClassTag[K], vt: ClassTag[V], + keyWritableFactory: WritableFactory[K], + valueWritableFactory: WritableFactory[V]) + : SequenceFileRDDFunctions[K, V] = { + implicit val keyConverter = keyWritableFactory.convert + implicit val valueConverter = valueWritableFactory.convert + new SequenceFileRDDFunctions(rdd, + keyWritableFactory.writableClass(kt), valueWritableFactory.writableClass(vt)) } implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](rdd: RDD[(K, V)]) diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 2b48916951430..059f8963691f0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -30,13 +30,35 @@ import org.apache.spark.Logging * through an implicit conversion. Note that this can't be part of PairRDDFunctions because * we need more implicit parameters to convert our keys and values to Writable. * - * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions. 
*/ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( - self: RDD[(K, V)]) + self: RDD[(K, V)], + _keyWritableClass: Class[_ <: Writable], + _valueWritableClass: Class[_ <: Writable]) extends Logging with Serializable { + @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0") + def this(self: RDD[(K, V)]) { + this(self, null, null) + } + + private val keyWritableClass = + if (_keyWritableClass == null) { + // pre 1.3.0, we need to use Reflection to get the Writable class + getWritableClass[K]() + } else { + _keyWritableClass + } + + private val valueWritableClass = + if (_valueWritableClass == null) { + // pre 1.3.0, we need to use Reflection to get the Writable class + getWritableClass[V]() + } else { + _valueWritableClass + } + private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = { val c = { if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) { @@ -55,6 +77,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag c.asInstanceOf[Class[_ <: Writable]] } + /** * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key * and value types. If the key or value are Writable, then we use their classes directly; @@ -65,26 +88,28 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { def anyToWritable[U <% Writable](u: U): Writable = u - val keyClass = getWritableClass[K] - val valueClass = getWritableClass[V] - val convertKey = !classOf[Writable].isAssignableFrom(self.keyClass) - val convertValue = !classOf[Writable].isAssignableFrom(self.valueClass) + // TODO We cannot force the return type of `anyToWritable` be same as keyWritableClass and + // valueWritableClass at the compile time. To implement that, we need to add type parameters to + // SequenceFileRDDFunctions. 
however, SequenceFileRDDFunctions is a public class so it will be a + // breaking change. + val convertKey = self.keyClass != keyWritableClass + val convertValue = self.valueClass != valueWritableClass - logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + - valueClass.getSimpleName + ")" ) + logInfo("Saving as sequence file of type (" + keyWritableClass.getSimpleName + "," + + valueWritableClass.getSimpleName + ")" ) val format = classOf[SequenceFileOutputFormat[Writable, Writable]] val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) + self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8cb15918baa8c..bc84e2351ad74 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties +import java.util.concurrent.{TimeUnit, Executors} import 
java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} @@ -28,8 +29,6 @@ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal -import akka.actor._ -import akka.actor.SupervisorStrategy.Stop import akka.pattern.ask import akka.util.Timeout @@ -39,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ -import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} +import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** @@ -64,11 +63,9 @@ class DAGScheduler( mapOutputTracker: MapOutputTrackerMaster, blockManagerMaster: BlockManagerMaster, env: SparkEnv, - clock: Clock = SystemClock) + clock: Clock = new SystemClock()) extends Logging { - import DAGScheduler._ - def this(sc: SparkContext, taskScheduler: TaskScheduler) = { this( sc, @@ -101,7 +98,13 @@ class DAGScheduler( private[scheduler] val activeJobs = new HashSet[ActiveJob] - // Contains the locations that each RDD's partitions are cached on + /** + * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids + * and its values are arrays indexed by partition numbers. Each array value is the set of + * locations where that RDD partition is cached. + * + * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). + */ private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with @@ -112,14 +115,10 @@ class DAGScheduler( // stray messages to detect. 
private val failedEpoch = new HashMap[String, Long] - private val dagSchedulerActorSupervisor = - env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this))) - // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() - private[scheduler] var eventProcessActor: ActorRef = _ /** If enabled, we may run certain actions like take() and first() locally. */ private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) @@ -127,27 +126,22 @@ class DAGScheduler( /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) - private def initializeEventProcessActor() { - // blocking the thread until supervisor is started, which ensures eventProcessActor is - // not null before any job is submitted - implicit val timeout = Timeout(30 seconds) - val initEventActorReply = - dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this)) - eventProcessActor = Await.result(initEventActorReply, timeout.duration). - asInstanceOf[ActorRef] - } + private val messageScheduler = + Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message")) - initializeEventProcessActor() + private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + private val outputCommitCoordinator = env.outputCommitCoordinator + // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { - eventProcessActor ! BeginEvent(task, taskInfo) + eventProcessLoop.post(BeginEvent(task, taskInfo)) } // Called to report that a task has completed and results are being fetched remotely. def taskGettingResult(taskInfo: TaskInfo) { - eventProcessActor ! 
GettingResultEvent(taskInfo) + eventProcessLoop.post(GettingResultEvent(taskInfo)) } // Called by TaskScheduler to report task completions or failures. @@ -158,7 +152,8 @@ class DAGScheduler( accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics) + eventProcessLoop.post( + CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) } /** @@ -180,21 +175,22 @@ class DAGScheduler( // Called by TaskScheduler when an executor fails. def executorLost(execId: String) { - eventProcessActor ! ExecutorLost(execId) + eventProcessLoop.post(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added def executorAdded(execId: String, host: String) { - eventProcessActor ! ExecutorAdded(execId, host) + eventProcessLoop.post(ExecutorAdded(execId, host)) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. def taskSetFailed(taskSet: TaskSet, reason: String) { - eventProcessActor ! 
TaskSetFailed(taskSet, reason) + eventProcessLoop.post(TaskSetFailed(taskSet, reason)) } - private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { + private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = cacheLocs.synchronized { + // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) @@ -205,7 +201,7 @@ class DAGScheduler( cacheLocs(rdd.id) } - private def clearCacheLocs() { + private def clearCacheLocs(): Unit = cacheLocs.synchronized { cacheLocs.clear() } @@ -496,8 +492,8 @@ class DAGScheduler( assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) - eventProcessActor ! JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) + eventProcessLoop.post(JobSubmitted( + jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)) waiter } @@ -537,8 +533,8 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() - eventProcessActor ! JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) + eventProcessLoop.post(JobSubmitted( + jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) listener.awaitResult() // Will throw an exception if the job fails } @@ -547,19 +543,19 @@ class DAGScheduler( */ def cancelJob(jobId: Int) { logInfo("Asked to cancel job " + jobId) - eventProcessActor ! 
JobCancelled(jobId) + eventProcessLoop.post(JobCancelled(jobId)) } def cancelJobGroup(groupId: String) { logInfo("Asked to cancel job group " + groupId) - eventProcessActor ! JobGroupCancelled(groupId) + eventProcessLoop.post(JobGroupCancelled(groupId)) } /** * Cancel all jobs that are running or waiting in the queue. */ def cancelAllJobs() { - eventProcessActor ! AllJobsCancelled + eventProcessLoop.post(AllJobsCancelled) } private[scheduler] def doCancelAllJobs() { @@ -575,7 +571,7 @@ class DAGScheduler( * Cancel all jobs associated with a running or scheduled stage. */ def cancelStage(stageId: Int) { - eventProcessActor ! StageCancelled(stageId) + eventProcessLoop.post(StageCancelled(stageId)) } /** @@ -661,7 +657,7 @@ class DAGScheduler( // completion events or stage abort stageIdToStage -= s.id jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult)) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult)) } } @@ -710,7 +706,7 @@ class DAGScheduler( stage.latestInfo.stageFailed(stageFailedMessage) listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) } - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } } @@ -749,9 +745,11 @@ class DAGScheduler( logInfo("Missing parents: " + getMissingParentStages(finalStage)) val shouldRunLocally = localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 + val jobSubmissionTime = clock.getTimeMillis() if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. 
- listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties)) runLocally(job) } else { jobIdToActiveJob(jobId) = job @@ -759,7 +757,8 @@ class DAGScheduler( finalStage.resultOfJob = Some(job) val stageIds = jobIdToStageIds(jobId).toArray val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) submitStage(finalStage) } } @@ -818,6 +817,7 @@ class DAGScheduler( // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) + outputCommitCoordinator.stageStart(stage.id) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. @@ -871,10 +871,11 @@ class DAGScheduler( logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - stage.latestInfo.submissionTime = Some(clock.getTime()) + stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should post // SparkListenerStageCompleted here in case there are no tasks to run. 
+ outputCommitCoordinator.stageEnd(stage.id) listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) @@ -889,8 +890,16 @@ class DAGScheduler( if (event.accumUpdates != null) { try { Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => - val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // In this instance, although the reference in Accumulators.originals is a WeakRef, + // it's guaranteed to exist since the event.accumUpdates Map exists + + val acc = Accumulators.originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] + case None => throw new NullPointerException("Non-existent reference to Accumulator") + } + // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get @@ -919,6 +928,9 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) + outputCommitCoordinator.taskCompleted(stageId, task.partitionId, + event.taskInfo.attempt, event.reason) + // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. if (event.reason != Success) { @@ -931,16 +943,17 @@ class DAGScheduler( // Skip all the actions if the stage has been cancelled. 
return } + val stage = stageIdToStage(task.stageId) def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0) + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" } if (errorMessage.isEmpty) { logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.latestInfo.completionTime = Some(clock.getTime()) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) } else { stage.latestInfo.stageFailed(errorMessage.get) logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) @@ -965,7 +978,8 @@ class DAGScheduler( if (job.numFinished == job.numPartitions) { markStageAsFinished(stage) cleanupStateForJobAndIndependentStages(job) - listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded)) + listenerBus.post( + SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) } // taskSucceeded runs some user code that might throw an exception. Make sure @@ -1059,16 +1073,15 @@ class DAGScheduler( if (disallowStageRetryForTest) { abortStage(failedStage, "Fetch failure will not retry stage due to testing config") - } else if (failedStages.isEmpty && eventProcessActor != null) { + } else if (failedStages.isEmpty) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. eventProcessActor may be - // null during unit tests. + // in that case the event will already have been scheduled. 
// TODO: Cancel running tasks in the stage - import env.actorSystem.dispatcher logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + s"$failedStage (${failedStage.name}) due to fetch failure") - env.actorSystem.scheduler.scheduleOnce( - RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages) + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) } failedStages += failedStage failedStages += mapStage @@ -1083,6 +1096,9 @@ class DAGScheduler( handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } + case commitDenied: TaskCommitDenied => + // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures @@ -1179,7 +1195,7 @@ class DAGScheduler( } val dependentJobs: Seq[ActiveJob] = activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq - failedStage.latestInfo.completionTime = Some(clock.getTime()) + failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") } @@ -1234,7 +1250,7 @@ class DAGScheduler( if (ableToCancelStages) { job.listener.jobFailed(error) cleanupStateForJobAndIndependentStages(job) - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } } @@ -1275,17 +1291,26 @@ class DAGScheduler( } /** - * Synchronized method that might be called from other threads. + * Gets the locality information associated with a partition of a particular RDD. + * + * This method is thread-safe and is called from both DAGScheduler and SparkContext. 
+ * * @param rdd whose partitions are to be looked at * @param partition to lookup locality information for * @return list of machines that are preferred by the partition */ private[spark] - def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { + def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { getPreferredLocsInternal(rdd, partition, new HashSet) } - /** Recursive implementation for getPreferredLocs. */ + /** + * Recursive implementation for getPreferredLocs. + * + * This method is thread-safe because it only accesses DAGScheduler state through thread-safe + * methods (getCacheLocs()); please be careful when modifying this method, because any new + * DAGScheduler state accessed by it may require additional synchronization. + */ private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, @@ -1326,40 +1351,21 @@ class DAGScheduler( def stop() { logInfo("Stopping DAGScheduler") - dagSchedulerActorSupervisor ! PoisonPill + eventProcessLoop.stop() taskScheduler.stop() } -} -private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler) - extends Actor with Logging { - - override val supervisorStrategy = - OneForOneStrategy() { - case x: Exception => - logError("eventProcesserActor failed; shutting down SparkContext", x) - try { - dagScheduler.doCancelAllJobs() - } catch { - case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) - } - dagScheduler.sc.stop() - Stop - } - - def receive = { - case p: Props => sender ! 
context.actorOf(p) - case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor") - } + // Start the event thread at the end of the constructor + eventProcessLoop.start() } -private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler) - extends Actor with Logging { +private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) + extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { /** * The main event loop of the DAG scheduler. */ - def receive = { + override def onReceive(event: DAGSchedulerEvent): Unit = event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) @@ -1398,7 +1404,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule dagScheduler.resubmitFailedStages() } - override def postStop() { + override def onError(e: Throwable): Unit = { + logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e) + try { + dagScheduler.doCancelAllJobs() + } catch { + case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) + } + dagScheduler.sc.stop() + } + + override def onStop() { // Cancel any active jobs in postStop hook dagScheduler.cleanUpAfterSchedulerStop() } @@ -1408,9 +1424,5 @@ private[spark] object DAGScheduler { // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in - val RESUBMIT_TIMEOUT = 200.milliseconds - - // The time, in millis, to wake up between polls of the completion queue in order to potentially - // resubmit failed stages - val POLL_TIMEOUT = 10L + val RESUBMIT_TIMEOUT = 200 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala 
b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 3bb54855bae44..8aa528ac573d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -169,7 +169,8 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + + " LOCAL_BYTES_READ=" + metrics.localBytesRead case None => "" } val writeMetrics = taskMetrics.shuffleWriteMetrics match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 36a6e6338faa6..be23056e7d423 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -17,10 +17,9 @@ package org.apache.spark.scheduler -import java.util.concurrent.{LinkedBlockingQueue, Semaphore} +import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.AsynchronousListenerBus /** * Asynchronously passes SparkListenerEvents to registered SparkListeners. @@ -29,113 +28,19 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when it receives a SparkListenerShutdown event, which is posted using stop(). */ -private[spark] class LiveListenerBus extends SparkListenerBus with Logging { - - /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. 
*/ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) - private var queueFullErrorMessageLogged = false - private var started = false - - // A counter that represents the number of events produced and consumed in the queue - private val eventLock = new Semaphore(0) - - private val listenerThread = new Thread("SparkListenerBus") { - setDaemon(true) - override def run(): Unit = Utils.logUncaughtExceptions { - while (true) { - eventLock.acquire() - // Atomically remove and process this event - LiveListenerBus.this.synchronized { - val event = eventQueue.poll - if (event == SparkListenerShutdown) { - // Get out of the while loop and shutdown the daemon thread - return - } - Option(event).foreach(postToAll) - } - } - } - } - - /** - * Start sending events to attached listeners. - * - * This first sends out all buffered events posted before this listener bus has started, then - * listens for any additional events asynchronously while the listener bus is still running. - * This should only be called once. - */ - def start() { - if (started) { - throw new IllegalStateException("Listener bus already started!") +private[spark] class LiveListenerBus + extends AsynchronousListenerBus[SparkListener, SparkListenerEvent]("SparkListenerBus") + with SparkListenerBus { + + private val logDroppedEvent = new AtomicBoolean(false) + + override def onDropEvent(event: SparkListenerEvent): Unit = { + if (logDroppedEvent.compareAndSet(false, true)) { + // Only log the following message once to avoid duplicated annoying logs. + logError("Dropping SparkListenerEvent because no remaining room in event queue. 
" + + "This likely means one of the SparkListeners is too slow and cannot keep up with " + + "the rate at which tasks are being started by the scheduler.") } - listenerThread.start() - started = true } - def post(event: SparkListenerEvent) { - val eventAdded = eventQueue.offer(event) - if (eventAdded) { - eventLock.release() - } else { - logQueueFullErrorMessage() - } - } - - /** - * For testing only. Wait until there are no more events in the queue, or until the specified - * time has elapsed. Return true if the queue has emptied and false is the specified time - * elapsed before the queue emptied. - */ - def waitUntilEmpty(timeoutMillis: Int): Boolean = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!queueIsEmpty) { - if (System.currentTimeMillis > finishTime) { - return false - } - /* Sleep rather than using wait/notify, because this is used only for testing and - * wait/notify add overhead in the general case. */ - Thread.sleep(10) - } - true - } - - /** - * For testing only. Return whether the listener daemon thread is still alive. - */ - def listenerThreadIsAlive: Boolean = synchronized { listenerThread.isAlive } - - /** - * Return whether the event queue is empty. - * - * The use of synchronized here guarantees that all events that once belonged to this queue - * have already been processed by all attached listeners, if this returns true. - */ - def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty } - - /** - * Log an error message to indicate that the event queue is full. Do this only once. - */ - private def logQueueFullErrorMessage(): Unit = { - if (!queueFullErrorMessageLogged) { - if (listenerThread.isAlive) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with" + - "the rate at which tasks are being started by the scheduler.") - } else { - logError("SparkListenerBus thread is dead! 
This means SparkListenerEvents have not" + - "been (and will no longer be) propagated to listeners for some time.") - } - queueFullErrorMessageLogged = true - } - } - - def stop() { - if (!started) { - throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!") - } - post(SparkListenerShutdown) - listenerThread.join() - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala new file mode 100644 index 0000000000000..759df023a6dcf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import scala.collection.mutable + +import akka.actor.{ActorRef, Actor} + +import org.apache.spark._ +import org.apache.spark.util.{AkkaUtils, ActorLogReceive} + +private sealed trait OutputCommitCoordinationMessage extends Serializable + +private case object StopCoordinator extends OutputCommitCoordinationMessage +private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) + +/** + * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" + * policy. + * + * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is + * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit + * output will be forwarded to the driver's OutputCommitCoordinator. + * + * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) + * for an extensive design discussion. + */ +private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { + + // Initialized by SparkEnv + var coordinatorActor: Option[ActorRef] = None + private val timeout = AkkaUtils.askTimeout(conf) + private val maxAttempts = AkkaUtils.numRetries(conf) + private val retryInterval = AkkaUtils.retryWaitMs(conf) + + private type StageId = Int + private type PartitionId = Long + private type TaskAttemptId = Long + + /** + * Map from active stages's id => partition id => task attempt with exclusive lock on committing + * output for that partition. + * + * Entries are added to the top-level map when stages start and are removed they finish + * (either successfully or unsuccessfully). + * + * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. 
+ */ + private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() + private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + + /** + * Called by tasks to ask whether they can commit their output to HDFS. + * + * If a task attempt has been authorized to commit, then all other attempts to commit the same + * task will be denied. If the authorized task attempt fails (e.g. due to its executor being + * lost), then a subsequent task attempt may be authorized to commit its output. + * + * @param stage the stage number + * @param partition the partition number + * @param attempt a unique identifier for this task attempt + * @return true if this task is authorized to commit, false otherwise + */ + def canCommit( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attempt) + coordinatorActor match { + case Some(actor) => + AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout) + case None => + logError( + "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") + false + } + } + + // Called by DAGScheduler + private[scheduler] def stageStart(stage: StageId): Unit = synchronized { + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + } + + // Called by DAGScheduler + private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + authorizedCommittersByStage.remove(stage) + } + + // Called by DAGScheduler + private[scheduler] def taskCompleted( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId, + reason: TaskEndReason): Unit = synchronized { + val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { + logDebug(s"Ignoring task completion for completed stage") + return + }) + reason match { + case Success => + // The task output has been committed successfully + case denied: TaskCommitDenied => + 
logInfo( + s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + case otherReason => + logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + + s" clearing lock") + authorizedCommitters.remove(partition) + } + } + + def stop(): Unit = synchronized { + coordinatorActor.foreach(_ ! StopCoordinator) + coordinatorActor = None + authorizedCommittersByStage.clear() + } + + // Marked private[scheduler] instead of private so this can be mocked in tests + private[scheduler] def handleAskPermissionToCommit( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId): Boolean = synchronized { + authorizedCommittersByStage.get(stage) match { + case Some(authorizedCommitters) => + authorizedCommitters.get(partition) match { + case Some(existingCommitter) => + logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + + s"existingCommitter = $existingCommitter") + false + case None => + logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") + authorizedCommitters(partition) = attempt + true + } + case None => + logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + false + } + } +} + +private[spark] object OutputCommitCoordinator { + + // This actor is used only for RPC + class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) + extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt) + case StopCoordinator => + logInfo("OutputCommitCoordinator stopped!") + context.stop(self) + sender ! 
true + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 584f4e7789d1a..d9c3a10dc5413 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -40,21 +40,24 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * * @param logData Stream containing event log data. * @param version Spark version that generated the events. + * @param sourceName Filename (or other source identifier) from whence @logData is being read */ - def replay(logData: InputStream, version: String) { + def replay(logData: InputStream, version: String, sourceName: String) { var currentLine: String = null + var lineNumber: Int = 1 try { val lines = Source.fromInputStream(logData).getLines() lines.foreach { line => currentLine = line postToAll(JsonProtocol.sparkEventFromJson(parse(line))) + lineNumber += 1 } } catch { case ioe: IOException => throw ioe case e: Exception => - logError("Exception in parsing Spark event log.", e) - logError("Malformed line: %s\n".format(currentLine)) + logError(s"Exception parsing Spark event log: $sourceName", e) + logError(s"Malformed line #$lineNumber: $currentLine\n") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 4840d8bd2d2f0..dd28ddb31de1f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -59,6 +59,7 @@ case class SparkListenerTaskEnd( @DeveloperApi case class SparkListenerJobStart( jobId: Int, + time: Long, stageInfos: Seq[StageInfo], properties: Properties = null) extends SparkListenerEvent { @@ -68,7 +69,11 @@ case class SparkListenerJobStart( } @DeveloperApi -case class 
SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent +case class SparkListenerJobEnd( + jobId: Int, + time: Long, + jobResult: JobResult) + extends SparkListenerEvent @DeveloperApi case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(String, String)]]) @@ -86,11 +91,11 @@ case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockMan case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent @DeveloperApi -case class SparkListenerExecutorAdded(executorId: String, executorInfo: ExecutorInfo) +case class SparkListenerExecutorAdded(time: Long, executorId: String, executorInfo: ExecutorInfo) extends SparkListenerEvent @DeveloperApi -case class SparkListenerExecutorRemoved(executorId: String) +case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent /** @@ -111,9 +116,6 @@ case class SparkListenerApplicationStart(appName: String, appId: Option[String], @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent -/** An event used in the listener to shutdown the listener daemon thread. 
*/ -private[spark] case object SparkListenerShutdown extends SparkListenerEvent - /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index e700c6af542f4..fe8a19a2c0cb9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -17,78 +17,47 @@ package org.apache.spark.scheduler -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.ListenerBus /** - * A SparkListenerEvent bus that relays events to its listeners + * A [[SparkListenerEvent]] bus that relays [[SparkListenerEvent]]s to its listeners */ -private[spark] trait SparkListenerBus extends Logging { - - // SparkListeners attached to this event bus - protected val sparkListeners = new ArrayBuffer[SparkListener] - with mutable.SynchronizedBuffer[SparkListener] - - def addListener(listener: SparkListener) { - sparkListeners += listener - } +private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] { - /** - * Post an event to all attached listeners. - * This does nothing if the event is SparkListenerShutdown. 
- */ - def postToAll(event: SparkListenerEvent) { + override def onPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = { event match { case stageSubmitted: SparkListenerStageSubmitted => - foreachListener(_.onStageSubmitted(stageSubmitted)) + listener.onStageSubmitted(stageSubmitted) case stageCompleted: SparkListenerStageCompleted => - foreachListener(_.onStageCompleted(stageCompleted)) + listener.onStageCompleted(stageCompleted) case jobStart: SparkListenerJobStart => - foreachListener(_.onJobStart(jobStart)) + listener.onJobStart(jobStart) case jobEnd: SparkListenerJobEnd => - foreachListener(_.onJobEnd(jobEnd)) + listener.onJobEnd(jobEnd) case taskStart: SparkListenerTaskStart => - foreachListener(_.onTaskStart(taskStart)) + listener.onTaskStart(taskStart) case taskGettingResult: SparkListenerTaskGettingResult => - foreachListener(_.onTaskGettingResult(taskGettingResult)) + listener.onTaskGettingResult(taskGettingResult) case taskEnd: SparkListenerTaskEnd => - foreachListener(_.onTaskEnd(taskEnd)) + listener.onTaskEnd(taskEnd) case environmentUpdate: SparkListenerEnvironmentUpdate => - foreachListener(_.onEnvironmentUpdate(environmentUpdate)) + listener.onEnvironmentUpdate(environmentUpdate) case blockManagerAdded: SparkListenerBlockManagerAdded => - foreachListener(_.onBlockManagerAdded(blockManagerAdded)) + listener.onBlockManagerAdded(blockManagerAdded) case blockManagerRemoved: SparkListenerBlockManagerRemoved => - foreachListener(_.onBlockManagerRemoved(blockManagerRemoved)) + listener.onBlockManagerRemoved(blockManagerRemoved) case unpersistRDD: SparkListenerUnpersistRDD => - foreachListener(_.onUnpersistRDD(unpersistRDD)) + listener.onUnpersistRDD(unpersistRDD) case applicationStart: SparkListenerApplicationStart => - foreachListener(_.onApplicationStart(applicationStart)) + listener.onApplicationStart(applicationStart) case applicationEnd: SparkListenerApplicationEnd => - foreachListener(_.onApplicationEnd(applicationEnd)) + 
listener.onApplicationEnd(applicationEnd) case metricsUpdate: SparkListenerExecutorMetricsUpdate => - foreachListener(_.onExecutorMetricsUpdate(metricsUpdate)) + listener.onExecutorMetricsUpdate(metricsUpdate) case executorAdded: SparkListenerExecutorAdded => - foreachListener(_.onExecutorAdded(executorAdded)) + listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => - foreachListener(_.onExecutorRemoved(executorRemoved)) - case SparkListenerShutdown => - } - } - - /** - * Apply the given function to all attached listeners, catching and logging any exception. - */ - private def foreachListener(f: SparkListener => Unit): Unit = { - sparkListeners.foreach { listener => - try { - f(listener) - } catch { - case e: Exception => - logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) - } + listener.onExecutorRemoved(executorRemoved) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 2367f7e2cf67e..847a4912eec13 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -55,7 +55,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) TaskContextHelper.setTaskContext(context) - context.taskMetrics.hostname = Utils.localHostName() + context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 4896ec845bbc9..3938580aeea59 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.concurrent.RejectedExecutionException import scala.language.existentials import scala.util.control.NonFatal @@ -77,7 +78,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul (deserializedResult, size) } - result.metrics.resultSize = size + result.metrics.setResultSize(size) scheduler.handleSuccessfulTask(taskSetManager, tid, result) } catch { case cnf: ClassNotFoundException => @@ -95,25 +96,30 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, serializedData: ByteBuffer) { var reason : TaskEndReason = UnknownReason - getTaskResultExecutor.execute(new Runnable { - override def run(): Unit = Utils.logUncaughtExceptions { - try { - if (serializedData != null && serializedData.limit() > 0) { - reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + try { + getTaskResultExecutor.execute(new Runnable { + override def run(): Unit = Utils.logUncaughtExceptions { + try { + if (serializedData != null && serializedData.limit() > 0) { + reason = serializer.get().deserialize[TaskEndReason]( + serializedData, Utils.getSparkClassLoader) + } + } catch { + case cnd: ClassNotFoundException => + // Log an error but keep going here -- the task failed, so not catastrophic + // if we can't deserialize the reason. + val loader = Utils.getContextOrSparkClassLoader + logError( + "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) + case ex: Exception => {} } - } catch { - case cnd: ClassNotFoundException => - // Log an error but keep going here -- the task failed, so not catastrophic if we can't - // deserialize the reason. 
- val loader = Utils.getContextOrSparkClassLoader - logError( - "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex: Exception => {} + scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } - scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) - } - }) + }) + } catch { + case e: RejectedExecutionException if sparkEnv.isStopped => + // ignore it + } } def stop() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a1dfb01062591..54f8fcfc416d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -158,7 +158,7 @@ private[spark] class TaskSchedulerImpl( val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet, maxTaskFailures) + val manager = createTaskSetManager(taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) @@ -168,7 +168,7 @@ private[spark] class TaskSchedulerImpl( if (!hasLaunchedTask) { logWarning("Initial job has not accepted any resources; " + "check your cluster UI to ensure that workers are registered " + - "and have sufficient memory") + "and have sufficient resources") } else { this.cancel() } @@ -180,6 +180,13 @@ private[spark] class TaskSchedulerImpl( backend.reviveOffers() } + // Label as private[scheduler] to allow tests to swap in different task set managers if necessary + private[scheduler] def createTaskSetManager( + taskSet: TaskSet, + maxTaskFailures: Int): TaskSetManager = { + new TaskSetManager(this, taskSet, maxTaskFailures) + } + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) 
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => @@ -361,7 +368,7 @@ private[spark] class TaskSchedulerImpl( dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) } - def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized { taskSetManager.handleTaskGettingResult(tid) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 5c94c6bbcb37b..529237f0d35dc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -51,7 +51,7 @@ private[spark] class TaskSetManager( sched: TaskSchedulerImpl, val taskSet: TaskSet, val maxTaskFailures: Int, - clock: Clock = SystemClock) + clock: Clock = new SystemClock()) extends Schedulable with Logging { val conf = sched.sc.conf @@ -166,7 +166,7 @@ private[spark] class TaskSetManager( // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. // We then move down if we manage to launch a "more local" task. var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTime() // Time we last launched a task at this level + var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level override def schedulableQueue = null @@ -281,7 +281,7 @@ private[spark] class TaskSetManager( val failed = failedExecutors.get(taskId).get return failed.contains(execId) && - clock.getTime() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT + clock.getTimeMillis() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT } false @@ -292,7 +292,8 @@ private[spark] class TaskSetManager( * an attempt running on this host, in case the host is slow. 
In addition, the task should meet * the given locality constraint. */ - private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) + // Labeled as protected to allow tests to override providing speculative tasks if necessary + protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) : Option[(Int, TaskLocality.Value)] = { speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set @@ -427,7 +428,7 @@ private[spark] class TaskSetManager( : Option[TaskDescription] = { if (!isZombie) { - val curTime = clock.getTime() + val curTime = clock.getTimeMillis() var allowedLocality = maxLocality @@ -458,7 +459,7 @@ private[spark] class TaskSetManager( lastLaunchTime = curTime } // Serialize and return the task - val startTime = clock.getTime() + val startTime = clock.getTimeMillis() val serializedTask: ByteBuffer = try { Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) } catch { @@ -506,13 +507,64 @@ private[spark] class TaskSetManager( * Get the level we can launch tasks according to delay scheduling, based on current wait time. 
*/ private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { - while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && - currentLocalityIndex < myLocalityLevels.length - 1) - { - // Jump to the next locality level, and remove our waiting time for the current one since - // we don't want to count it again on the next one - lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 + // Remove the scheduled or finished tasks lazily + def tasksNeedToBeScheduledFrom(pendingTaskIds: ArrayBuffer[Int]): Boolean = { + var indexOffset = pendingTaskIds.size + while (indexOffset > 0) { + indexOffset -= 1 + val index = pendingTaskIds(indexOffset) + if (copiesRunning(index) == 0 && !successful(index)) { + return true + } else { + pendingTaskIds.remove(indexOffset) + } + } + false + } + // Walk through the list of tasks that can be scheduled at each location and returns true + // if there are any tasks that still need to be scheduled. Lazily cleans up tasks that have + // already been scheduled. 
+ def moreTasksToRunIn(pendingTasks: HashMap[String, ArrayBuffer[Int]]): Boolean = { + val emptyKeys = new ArrayBuffer[String] + val hasTasks = pendingTasks.exists { + case (id: String, tasks: ArrayBuffer[Int]) => + if (tasksNeedToBeScheduledFrom(tasks)) { + true + } else { + emptyKeys += id + false + } + } + // The key could be executorId, host or rackId + emptyKeys.foreach(id => pendingTasks.remove(id)) + hasTasks + } + + while (currentLocalityIndex < myLocalityLevels.length - 1) { + val moreTasks = myLocalityLevels(currentLocalityIndex) match { + case TaskLocality.PROCESS_LOCAL => moreTasksToRunIn(pendingTasksForExecutor) + case TaskLocality.NODE_LOCAL => moreTasksToRunIn(pendingTasksForHost) + case TaskLocality.NO_PREF => pendingTasksWithNoPrefs.nonEmpty + case TaskLocality.RACK_LOCAL => moreTasksToRunIn(pendingTasksForRack) + } + if (!moreTasks) { + // This is a performance optimization: if there are no more tasks that can + // be scheduled at a particular locality level, there is no point in waiting + // for the locality wait timeout (SPARK-4939). 
+ lastLaunchTime = curTime + logDebug(s"No tasks for locality level ${myLocalityLevels(currentLocalityIndex)}, " + + s"so moving to locality level ${myLocalityLevels(currentLocalityIndex + 1)}") + currentLocalityIndex += 1 + } else if (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) { + // Jump to the next locality level, and reset lastLaunchTime so that the next locality + // wait timer doesn't immediately expire + lastLaunchTime += localityWaits(currentLocalityIndex) + currentLocalityIndex += 1 + logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex)} after waiting for " + + s"${localityWaits(currentLocalityIndex)}ms") + } else { + return myLocalityLevels(currentLocalityIndex) + } } myLocalityLevels(currentLocalityIndex) } @@ -542,7 +594,7 @@ private[spark] class TaskSetManager( /** * Check whether has enough quota to fetch the result with `size` bytes */ - def canFetchMoreResults(size: Long): Boolean = synchronized { + def canFetchMoreResults(size: Long): Boolean = sched.synchronized { totalResultSize += size calculatedTasks += 1 if (maxResultSize > 0 && totalResultSize > maxResultSize) { @@ -622,7 +674,7 @@ private[spark] class TaskSetManager( return } val key = ef.description - val now = clock.getTime() + val now = clock.getTimeMillis() val (printFull, dupCount) = { if (recentExceptions.contains(key)) { val (dupCount, printTime) = recentExceptions(key) @@ -654,10 +706,13 @@ private[spark] class TaskSetManager( } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). 
- put(info.executorId, clock.getTime()) + put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) addPendingTask(index) - if (!isZombie && state != TaskState.KILLED) { + if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) { + // If a task failed because its attempt to commit was denied, do not count this failure + // towards failing the stage. This is intended to prevent spurious stage failures in cases + // where many speculative tasks are launched and denied to commit. assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { @@ -671,7 +726,7 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } - def abort(message: String) { + def abort(message: String): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error sched.dagScheduler.taskSetFailed(taskSet, message) isZombie = true @@ -766,7 +821,7 @@ private[spark] class TaskSetManager( val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { - val time = clock.getTime() + val time = clock.getTimeMillis() val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray Arrays.sort(durations) val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 1da6fe976da5b..9bf74f4be198d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -39,7 +39,11 @@ 
private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver - case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + case class RegisterExecutor( + executorId: String, + hostPort: String, + cores: Int, + logUrls: Map[String, String]) extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5786d367464f4..6f77fa32ce37b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -86,7 +86,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } def receiveWithLogging = { - case RegisterExecutor(executorId, hostPort, cores) => + case RegisterExecutor(executorId, hostPort, cores, logUrls) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { sender ! 
RegisterExecutorFailed("Duplicate executor ID: " + executorId) @@ -98,7 +98,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores) + val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -108,7 +108,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") } } - listenerBus.post(SparkListenerExecutorAdded(executorId, data)) + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) makeOffers() } @@ -216,7 +217,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) scheduler.executorLost(executorId, SlaveLost(reason)) - listenerBus.post(SparkListenerExecutorRemoved(executorId)) + listenerBus.post( + SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) case None => logError(s"Asked to remove non-existent executor $executorId") } } @@ -309,9 +311,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste /** * Request an additional number of executors from the cluster manager. - * Return whether the request is acknowledged. + * @return whether the request is acknowledged. 
*/ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized { + if (numAdditionalExecutors < 0) { + throw new IllegalArgumentException( + "Attempted to request a negative number of additional executor(s) " + + s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!") + } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") logDebug(s"Number of pending executors is now $numPendingExecutors") numPendingExecutors += numAdditionalExecutors @@ -320,6 +327,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste doRequestTotalExecutors(newTotal) } + /** + * Express a preference to the cluster manager for a given total number of executors. This can + * result in canceling pending requests or filing additional requests. + * @return whether the request is acknowledged. + */ + final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized { + if (numExecutors < 0) { + throw new IllegalArgumentException( + "Attempted to request a negative number of executor(s) " + + s"$numExecutors from the cluster manager. Please specify a positive number!") + } + numPendingExecutors = + math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) + doRequestTotalExecutors(numExecutors) + } + /** * Request executors from the cluster manager by specifying the total number desired, * including existing pending and running executors. @@ -330,7 +353,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste * insufficient resources to satisfy the first request. We make the assumption here that the * cluster manager will eventually fulfill all requests when resources free up. * - * Return whether the request is acknowledged. + * @return whether the request is acknowledged. 
*/ protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index eb52ddfb1eab1..5e571efe76720 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -33,5 +33,6 @@ private[cluster] class ExecutorData( val executorAddress: Address, override val executorHost: String, var freeCores: Int, - override val totalCores: Int -) extends ExecutorInfo(executorHost, totalCores) + override val totalCores: Int, + override val logUrlMap: Map[String, String] +) extends ExecutorInfo(executorHost, totalCores, logUrlMap) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala index b4738e64c9391..7f218566146a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala @@ -25,8 +25,8 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi class ExecutorInfo( val executorHost: String, - val totalCores: Int -) { + val totalCores: Int, + val logUrlMap: Map[String, String]) { def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo] @@ -34,12 +34,13 @@ class ExecutorInfo( case that: ExecutorInfo => (that canEqual this) && executorHost == that.executorHost && - totalCores == that.totalCores + totalCores == that.totalCores && + logUrlMap == that.logUrlMap case _ => false } override def hashCode(): Int = { - val state = Seq(executorHost, totalCores) + val state = Seq(executorHost, totalCores, logUrlMap) state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index ee10aa061f4e9..06786a59524e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.util.AkkaUtils private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, @@ -38,7 +39,8 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + val driverUrl = AkkaUtils.address( + AkkaUtils.protocol(actorSystem), SparkEnv.driverActorSystemName, sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port"), diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 7eb87a564d6f5..a0aa555f6244f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,11 +17,13 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.Semaphore + import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.util.Utils +import org.apache.spark.util.{AkkaUtils, Utils} private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, @@ -31,28 +33,34 @@ private[spark] 
class SparkDeploySchedulerBackend( with AppClientListener with Logging { - var client: AppClient = null - var stopping = false - var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ - @volatile var appId: String = _ + private var client: AppClient = null + private var stopping = false + + @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _ + @volatile private var appId: String = _ - val registrationLock = new Object() - var registrationDone = false + private val registrationBarrier = new Semaphore(0) - val maxCores = conf.getOption("spark.cores.max").map(_.toInt) - val totalExpectedCores = maxCores.getOrElse(0) + private val maxCores = conf.getOption("spark.cores.max").map(_.toInt) + private val totalExpectedCores = maxCores.getOrElse(0) override def start() { super.start() // The endpoint for executors to talk to us - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + val driverUrl = AkkaUtils.address( + AkkaUtils.protocol(actorSystem), SparkEnv.driverActorSystemName, conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}", - "{{WORKER_URL}}") + val args = Seq( + "--driver-url", driverUrl, + "--executor-id", "{{EXECUTOR_ID}}", + "--hostname", "{{HOSTNAME}}", + "--cores", "{{CORES}}", + "--app-id", "{{APP_ID}}", + "--worker-url", "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath") @@ -89,8 +97,10 @@ private[spark] class SparkDeploySchedulerBackend( stopping = true super.stop() client.stop() - if (shutdownCallback != null) { - shutdownCallback(this) + + val callback = shutdownCallback + if (callback != null) { + callback(this) } } @@ -143,18 +153,11 @@ private[spark] class SparkDeploySchedulerBackend( } private def 
waitForRegistration() = { - registrationLock.synchronized { - while (!registrationDone) { - registrationLock.wait() - } - } + registrationBarrier.acquire() } private def notifyContext() = { - registrationLock.synchronized { - registrationDone = true - registrationLock.notifyAll() - } + registrationBarrier.release() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 5289661eb896b..90dfe14352a8e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -31,7 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, AkkaUtils} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -143,7 +143,8 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + val driverUrl = AkkaUtils.address( + AkkaUtils.protocol(sc.env.actorSystem), SparkEnv.driverActorSystemName, conf.get("spark.driver.host"), conf.get("spark.driver.port"), @@ -153,18 +154,25 @@ private[spark] class CoarseMesosSchedulerBackend( if (uri == null) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( - prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores, 
appId)) + "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" + .format(prefixEnv, runScript) + + s" --driver-url $driverUrl" + + s" --executor-id ${offer.getSlaveId.getValue}" + + s" --hostname ${offer.getHostname}" + + s" --cores $numCores" + + s" --app-id $appId") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.split('/').last.split('.').head command.setValue( - ("cd %s*; %s " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s") - .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores, appId)) + s"cd $basename*; $prefixEnv " + + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + + s" --driver-url $driverUrl" + + s" --executor-id ${offer.getSlaveId.getValue}" + + s" --hostname ${offer.getHostname}" + + s" --cores $numCores" + + s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } command.build() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index d252fe8595fb8..cfb6592e14aa8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -30,6 +30,7 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, ExecutorInfo => MesosExecutorInfo, _} +import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler._ @@ -123,14 +124,15 @@ private[spark] class MesosSchedulerBackend( val command = CommandInfo.newBuilder() .setEnvironment(environment) val uri 
= sc.conf.get("spark.executor.uri", null) + val executorBackendName = classOf[MesosExecutorBackend].getName if (uri == null) { - val executorPath = new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath - command.setValue("%s %s".format(prefixEnv, executorPath)) + val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath + command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.split('/').last.split('.').head - command.setValue("cd %s*; %s ./sbin/spark-executor".format(basename, prefixEnv)) + command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } val cpus = Resource.newBuilder() @@ -267,8 +269,9 @@ private[spark] class MesosSchedulerBackend( mesosTasks.foreach { case (slaveId, tasks) => slaveIdToWorkerOffer.get(slaveId).foreach(o => - listenerBus.post(SparkListenerExecutorAdded(slaveId, - new ExecutorInfo(o.host, o.cores))) + listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId, + // TODO: Add support for log urls for Mesos + new ExecutorInfo(o.host, o.cores, Map.empty))) ) d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } @@ -325,7 +328,7 @@ private[spark] class MesosSchedulerBackend( synchronized { if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone - removeExecutor(taskIdToSlaveId(tid)) + removeExecutor(taskIdToSlaveId(tid), "Lost executor") } if (isFinished(status.getState)) { taskIdToSlaveId.remove(tid) @@ -357,9 +360,9 @@ private[spark] class MesosSchedulerBackend( /** * Remove executor associated with slaveId in a thread safe manner. 
*/ - private def removeExecutor(slaveId: String) = { + private def removeExecutor(slaveId: String, reason: String) = { synchronized { - listenerBus.post(SparkListenerExecutorRemoved(slaveId)) + listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason)) slaveIdsWithExecutors -= slaveId } } @@ -367,7 +370,7 @@ private[spark] class MesosSchedulerBackend( private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) - removeExecutor(slaveId.getValue) + removeExecutor(slaveId.getValue, reason.toString) scheduler.executorLost(slaveId.getValue, reason) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala index 4416ce92ade25..5e7e6567a3e06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala @@ -21,24 +21,29 @@ import java.nio.ByteBuffer import org.apache.mesos.protobuf.ByteString +import org.apache.spark.Logging + /** * Wrapper for serializing the data sent when launching Mesos tasks. 
*/ private[spark] case class MesosTaskLaunchData( serializedTask: ByteBuffer, - attemptNumber: Int) { + attemptNumber: Int) extends Logging { def toByteString: ByteString = { val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit) dataBuffer.putInt(attemptNumber) dataBuffer.put(serializedTask) + dataBuffer.rewind + logDebug(s"ByteBuffer size: [${dataBuffer.remaining}]") ByteString.copyFrom(dataBuffer) } } -private[spark] object MesosTaskLaunchData { +private[spark] object MesosTaskLaunchData extends Logging { def fromByteString(byteString: ByteString): MesosTaskLaunchData = { val byteBuffer = byteString.asReadOnlyByteBuffer() + logDebug(s"ByteBuffer size: [${byteBuffer.remaining}]") val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes val serializedTask = byteBuffer.slice() // subsequence starting at the current position MesosTaskLaunchData(serializedTask, attemptNumber) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 05b6fa54564b7..4676b828d3d89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer +import scala.concurrent.duration._ + import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} @@ -46,6 +48,8 @@ private[spark] class LocalActor( private val totalCores: Int) extends Actor with ActorLogReceive with Logging { + import context.dispatcher // to use Akka's scheduler.scheduleOnce() + private var freeCores = totalCores private val localExecutorId = SparkContext.DRIVER_IDENTIFIER @@ -74,11 +78,16 @@ private[spark] class LocalActor( def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) - for (task <- 
scheduler.resourceOffers(offers).flatten) { + val tasks = scheduler.resourceOffers(offers).flatten + for (task <- tasks) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, task.name, task.serializedTask) } + if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { + // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout + context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers) + } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 662a7b91248aa..1baa0e009f3ae 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -27,7 +27,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils -private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int) +private[spark] class JavaSerializationStream( + out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) extends SerializationStream { private val objOut = new ObjectOutputStream(out) private var counter = 0 @@ -39,7 +40,12 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In * the stream 'resets' object class descriptions have to be re-written) */ def writeObject[T: ClassTag](t: T): SerializationStream = { - objOut.writeObject(t) + try { + objOut.writeObject(t) + } catch { + case e: NotSerializableException if extraDebugInfo => + throw SerializationDebugger.improveException(t, e) + } counter += 1 if (counterReset > 0 && counter >= counterReset) { objOut.reset() @@ -64,7 +70,8 @@ extends DeserializationStream { } -private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) +private[spark] class 
JavaSerializerInstance( + counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader) extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { @@ -88,11 +95,11 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade } override def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s, counterReset) + new JavaSerializationStream(s, counterReset, extraDebugInfo) } override def deserializeStream(s: InputStream): DeserializationStream = { - new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader) + new JavaDeserializationStream(s, defaultClassLoader) } def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { @@ -111,17 +118,20 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade @DeveloperApi class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) + private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true) override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) - new JavaSerializerInstance(counterReset, classLoader) + new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeInt(counterReset) + out.writeBoolean(extraDebugInfo) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { counterReset = in.readInt() + extraDebugInfo = in.readBoolean() } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d56e23ce4478a..02158aa0f866e 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ 
b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -58,14 +58,6 @@ class KryoSerializer(conf: SparkConf) private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") .split(',') .filter(!_.isEmpty) - .map { className => - try { - Class.forName(className) - } catch { - case e: Exception => - throw new SparkException("Failed to load class to register with Kryo", e) - } - } def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) @@ -97,7 +89,8 @@ class KryoSerializer(conf: SparkConf) // Use the default classloader when calling the user registrator. Thread.currentThread.setContextClassLoader(classLoader) // Register classes given through spark.kryo.classesToRegister. - classesToRegister.foreach { clazz => kryo.register(clazz) } + classesToRegister + .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. userRegistrator .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala new file mode 100644 index 0000000000000..cecb992579655 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} +import java.lang.reflect.{Field, Method} +import java.security.AccessController + +import scala.annotation.tailrec +import scala.collection.mutable + +import org.apache.spark.Logging + +private[serializer] object SerializationDebugger extends Logging { + + /** + * Improve the given NotSerializableException with the serialization path leading from the given + * object to the problematic object. This is turned off automatically if + * `sun.io.serialization.extendedDebugInfo` flag is turned on for the JVM. + */ + def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { + if (enableDebugging && reflect != null) { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } else { + e + } + } + + /** + * Find the path leading to a not serializable object. This method is modeled after OpenJDK's + * serialization mechanism, and handles the following cases: + * - primitives + * - arrays of primitives + * - arrays of non-primitive objects + * - Serializable objects + * - Externalizable objects + * - writeReplace + * + * It does not yet handle writeObject override, but that shouldn't be too hard to do either. 
+ */ + def find(obj: Any): List[String] = { + new SerializationDebugger().visit(obj, List.empty) + } + + private[serializer] var enableDebugging: Boolean = { + !AccessController.doPrivileged(new sun.security.action.GetBooleanAction( + "sun.io.serialization.extendedDebugInfo")).booleanValue() + } + + private class SerializationDebugger { + + /** A set to track the list of objects we have visited, to avoid cycles in the graph. */ + private val visited = new mutable.HashSet[Any] + + /** + * Visit the object and its fields and stop when we find an object that is not serializable. + * Return the path as a list. If everything can be serialized, return an empty list. + */ + def visit(o: Any, stack: List[String]): List[String] = { + if (o == null) { + List.empty + } else if (visited.contains(o)) { + List.empty + } else { + visited += o + o match { + // Primitive value, string, and primitive arrays are always serializable + case _ if o.getClass.isPrimitive => List.empty + case _: String => List.empty + case _ if o.getClass.isArray && o.getClass.getComponentType.isPrimitive => List.empty + + // Traverse non primitive array. + case a: Array[_] if o.getClass.isArray && !o.getClass.getComponentType.isPrimitive => + val elem = s"array (class ${a.getClass.getName}, size ${a.length})" + visitArray(o.asInstanceOf[Array[_]], elem :: stack) + + case e: java.io.Externalizable => + val elem = s"externalizable object (class ${e.getClass.getName}, $e)" + visitExternalizable(e, elem :: stack) + + case s: Object with java.io.Serializable => + val elem = s"object (class ${s.getClass.getName}, $s)" + visitSerializable(s, elem :: stack) + + case _ => + // Found an object that is not serializable! 
+ s"object not serializable (class: ${o.getClass.getName}, value: $o)" :: stack + } + } + } + + private def visitArray(o: Array[_], stack: List[String]): List[String] = { + var i = 0 + while (i < o.length) { + val childStack = visit(o(i), s"element of array (index: $i)" :: stack) + if (childStack.nonEmpty) { + return childStack + } + i += 1 + } + return List.empty + } + + private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = + { + val fieldList = new ListObjectOutput + o.writeExternal(fieldList) + val childObjects = fieldList.outputArray + var i = 0 + while (i < childObjects.length) { + val childStack = visit(childObjects(i), "writeExternal data" :: stack) + if (childStack.nonEmpty) { + return childStack + } + i += 1 + } + return List.empty + } + + private def visitSerializable(o: Object, stack: List[String]): List[String] = { + // An object contains multiple slots in serialization. + // Get the slots and visit fields in all of them. + val (finalObj, desc) = findObjectAndDescriptor(o) + val slotDescs = desc.getSlotDescs + var i = 0 + while (i < slotDescs.length) { + val slotDesc = slotDescs(i) + if (slotDesc.hasWriteObjectMethod) { + // TODO: Handle classes that specify writeObject method. + } else { + val fields: Array[ObjectStreamField] = slotDesc.getFields + val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) + val numPrims = fields.length - objFieldValues.length + desc.getObjFieldValues(finalObj, objFieldValues) + + var j = 0 + while (j < objFieldValues.length) { + val fieldDesc = fields(numPrims + j) + val elem = s"field (class: ${slotDesc.getName}" + + s", name: ${fieldDesc.getName}" + + s", type: ${fieldDesc.getType})" + val childStack = visit(objFieldValues(j), elem :: stack) + if (childStack.nonEmpty) { + return childStack + } + j += 1 + } + + } + i += 1 + } + return List.empty + } + } + + /** + * Find the object to serialize and the associated [[ObjectStreamClass]]. 
This method handles + * writeReplace in Serializable. It starts with the object itself, and keeps calling the + * writeReplace method until there is no more + */ + @tailrec + private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { + val cl = o.getClass + val desc = ObjectStreamClass.lookupAny(cl) + if (!desc.hasWriteReplaceMethod) { + (o, desc) + } else { + // write place + findObjectAndDescriptor(desc.invokeWriteReplace(o)) + } + } + + /** + * A dummy [[ObjectOutput]] that simply saves the list of objects written by a writeExternal + * call, and returns them through `outputArray`. + */ + private class ListObjectOutput extends ObjectOutput { + private val output = new mutable.ArrayBuffer[Any] + def outputArray: Array[Any] = output.toArray + override def writeObject(o: Any): Unit = output += o + override def flush(): Unit = {} + override def write(i: Int): Unit = {} + override def write(bytes: Array[Byte]): Unit = {} + override def write(bytes: Array[Byte], i: Int, i1: Int): Unit = {} + override def close(): Unit = {} + override def writeFloat(v: Float): Unit = {} + override def writeChars(s: String): Unit = {} + override def writeDouble(v: Double): Unit = {} + override def writeUTF(s: String): Unit = {} + override def writeShort(i: Int): Unit = {} + override def writeInt(i: Int): Unit = {} + override def writeBoolean(b: Boolean): Unit = {} + override def writeBytes(s: String): Unit = {} + override def writeChar(i: Int): Unit = {} + override def writeLong(l: Long): Unit = {} + override def writeByte(i: Int): Unit = {} + } + + /** An implicit class that allows us to call private methods of ObjectStreamClass. 
*/ + implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { + def getSlotDescs: Array[ObjectStreamClass] = { + reflect.GetClassDataLayout.invoke(desc).asInstanceOf[Array[Object]].map { + classDataSlot => reflect.DescField.get(classDataSlot).asInstanceOf[ObjectStreamClass] + } + } + + def hasWriteObjectMethod: Boolean = { + reflect.HasWriteObjectMethod.invoke(desc).asInstanceOf[Boolean] + } + + def hasWriteReplaceMethod: Boolean = { + reflect.HasWriteReplaceMethod.invoke(desc).asInstanceOf[Boolean] + } + + def invokeWriteReplace(obj: Object): Object = { + reflect.InvokeWriteReplace.invoke(desc, obj) + } + + def getNumObjFields: Int = { + reflect.GetNumObjFields.invoke(desc).asInstanceOf[Int] + } + + def getObjFieldValues(obj: Object, out: Array[Object]): Unit = { + reflect.GetObjFieldValues.invoke(desc, obj, out) + } + } + + /** + * Object to hold all the reflection objects. If we run on a JVM that we cannot understand, + * this field will be null and this the debug helper should be disabled. 
+ */ + private val reflect: ObjectStreamClassReflection = try { + new ObjectStreamClassReflection + } catch { + case e: Exception => + logWarning("Cannot find private methods using reflection", e) + null + } + + private class ObjectStreamClassReflection { + /** ObjectStreamClass.getClassDataLayout */ + val GetClassDataLayout: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("getClassDataLayout") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.hasWriteObjectMethod */ + val HasWriteObjectMethod: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteObjectMethod") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.hasWriteReplaceMethod */ + val HasWriteReplaceMethod: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteReplaceMethod") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.invokeWriteReplace */ + val InvokeWriteReplace: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("invokeWriteReplace", classOf[Object]) + f.setAccessible(true) + f + } + + /** ObjectStreamClass.getNumObjFields */ + val GetNumObjFields: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("getNumObjFields") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.getObjFieldValues */ + val GetObjFieldValues: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod( + "getObjFieldValues", classOf[Object], classOf[Array[Object]]) + f.setAccessible(true) + f + } + + /** ObjectStreamClass$ClassDataSlot.desc field */ + val DescField: Field = { + val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + f.setAccessible(true) + f + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index e3e7434df45b0..7a2c5ae32d98b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala 
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -86,6 +86,12 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { context.taskMetrics.updateShuffleReadMetrics() }) - new InterruptibleIterator[T](context, completionIter) + new InterruptibleIterator[T](context, completionIter) { + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + override def next(): T = { + readMetrics.incRecordsRead(1) + delegate.next() + } + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index de72148ccc7ac..41bafabde05b9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -59,8 +59,8 @@ private[spark] class HashShuffleReader[K, C]( // the ExternalSorter won't spill to disk. val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) sorter.insertAll(aggregatedIter) - context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled - context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled + context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled) sorter.iterator case None => aggregatedIter diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8bc5a1cd18b64..86dbd89f0ffb8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -53,7 +53,7 @@ private[spark] class BlockResult( readMethod: DataReadMethod.Value, bytes: Long) { val inputMetrics = new InputMetrics(readMethod) - inputMetrics.addBytesRead(bytes) + inputMetrics.incBytesRead(bytes) } /** diff --git 
a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 9c469370ffe1f..81164178b9e8e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -29,7 +29,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics * appending data to an existing block, and can guarantee atomicity in the case of faults * as it allows the caller to revert partial writes. * - * This interface does not support concurrent writes. + * This interface does not support concurrent writes. Also, once the writer has + * been opened, it cannot be reopened again. */ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { @@ -95,6 +96,7 @@ private[spark] class DiskBlockObjectWriter( private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null private var initialized = false + private var hasBeenClosed = false /** * Cursors used to represent positions in the file. @@ -115,11 +117,16 @@ private[spark] class DiskBlockObjectWriter( private var finalPosition: Long = -1 private var reportedPosition = initialPosition - /** Calling channel.position() to update the write metrics can be a little bit expensive, so we - * only call it every N writes */ - private var writesSinceMetricsUpdate = 0 + /** + * Keep track of number of records written and also use this to periodically + * output bytes written since the latter is expensive to do for each record. + */ + private var numRecordsWritten = 0 override def open(): BlockObjectWriter = { + if (hasBeenClosed) { + throw new IllegalStateException("Writer already closed. 
Cannot be reopened.") + } fos = new FileOutputStream(file, true) ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() @@ -145,6 +152,7 @@ private[spark] class DiskBlockObjectWriter( ts = null objOut = null initialized = false + hasBeenClosed = true } } @@ -160,14 +168,15 @@ private[spark] class DiskBlockObjectWriter( } finalPosition = file.length() // In certain compression codecs, more bytes are written after close() is called - writeMetrics.shuffleBytesWritten += (finalPosition - reportedPosition) + writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) } // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. override def revertPartialWritesAndClose() { try { - writeMetrics.shuffleBytesWritten -= (reportedPosition - initialPosition) + writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) if (initialized) { objOut.flush() @@ -193,12 +202,11 @@ private[spark] class DiskBlockObjectWriter( } objOut.writeObject(value) + numRecordsWritten += 1 + writeMetrics.incShuffleRecordsWritten(1) - if (writesSinceMetricsUpdate == 32) { - writesSinceMetricsUpdate = 0 + if (numRecordsWritten % 32 == 0) { updateBytesWritten() - } else { - writesSinceMetricsUpdate += 1 } } @@ -212,14 +220,14 @@ private[spark] class DiskBlockObjectWriter( */ private def updateBytesWritten() { val pos = channel.position() - writeMetrics.shuffleBytesWritten += (pos - reportedPosition) + writeMetrics.incShuffleBytesWritten(pos - reportedPosition) reportedPosition = pos } private def callWithTiming(f: => Unit) = { val start = System.nanoTime() f - writeMetrics.shuffleWriteTime += (System.nanoTime() - start) + writeMetrics.incShuffleWriteTime(System.nanoTime() - start) } // For testing diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index af05eb3ca69ce..12cd8ea3bdf1f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -17,9 +17,8 @@ package org.apache.spark.storage +import java.util.UUID import java.io.{IOException, File} -import java.text.SimpleDateFormat -import java.util.{Date, Random, UUID} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode @@ -37,7 +36,6 @@ import org.apache.spark.util.Utils private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf) extends Logging { - private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 private[spark] val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) @@ -51,7 +49,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - addShutdownHook() + private val shutdownHook = addShutdownHook() /** Looks up a file by hashing it into one of our local subdirectories. 
*/ // This method should be kept in sync with @@ -123,48 +121,42 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def createLocalDirs(conf: SparkConf): Array[File] = { - val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => - var foundLocalDir = false - var localDir: File = null - var localDirId: String = null - var tries = 0 - val rand = new Random() - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - localDir = new File(rootDir, s"spark-local-$localDirId") - if (!localDir.exists) { - foundLocalDir = localDir.mkdirs() - } - } catch { - case e: Exception => - logWarning(s"Attempt $tries to create local dir $localDir failed", e) - } - } - if (!foundLocalDir) { - logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create local dir in $rootDir." + - " Ignoring this directory.") - None - } else { + try { + val localDir = Utils.createDirectory(rootDir, "blockmgr") logInfo(s"Created local directory at $localDir") Some(localDir) + } catch { + case e: IOException => + logError(s"Failed to create local dir in $rootDir. Ignoring this directory.", e) + None } } } - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { + private def addShutdownHook(): Thread = { + val shutdownHook = new Thread("delete Spark local dirs") { override def run(): Unit = Utils.logUncaughtExceptions { logDebug("Shutdown hook called") - DiskBlockManager.this.stop() + DiskBlockManager.this.doStop() } - }) + } + Runtime.getRuntime.addShutdownHook(shutdownHook) + shutdownHook } /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { + // Remove the shutdown hook. It causes memory leaks if we leave it around. 
+ try { + Runtime.getRuntime.removeShutdownHook(shutdownHook) + } catch { + case e: IllegalStateException => None + } + doStop() + } + + private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { localDirs.foreach { localDir => diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 2499c11a65b0e..8f28ef49a8a6f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -156,8 +156,8 @@ final class ShuffleBlockFetcherIterator( // This needs to be released after use. buf.retain() results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) - shuffleMetrics.remoteBytesRead += buf.size - shuffleMetrics.remoteBlocksFetched += 1 + shuffleMetrics.incRemoteBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(1) } logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } @@ -233,7 +233,8 @@ final class ShuffleBlockFetcherIterator( val blockId = iter.next() try { val buf = blockManager.getBlockData(blockId) - shuffleMetrics.localBlocksFetched += 1 + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() results.put(new SuccessFetchResult(blockId, 0, buf)) } catch { @@ -277,7 +278,7 @@ final class ShuffleBlockFetcherIterator( currentResult = results.take() val result = currentResult val stopFetchWait = System.currentTimeMillis() - shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) + shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { case SuccessFetchResult(_, size, _) => bytesInFlight -= size diff --git 
a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 27ba9e18237b5..67f572e79314d 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -28,7 +28,6 @@ import org.apache.spark._ * of them will be combined together, showed in one line. */ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { - // Carrige return val CR = '\r' // Update period of progress bar, in milliseconds @@ -121,4 +120,10 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { clear() lastFinishTime = System.currentTimeMillis() } + + /** + * Tear down the timer thread. The timer thread is a GC root, and it retains the entire + * SparkContext if it's not terminated. + */ + def stop(): Unit = timer.cancel() } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 88fed833f922d..bf4b24e98b134 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -62,17 +62,22 @@ private[spark] object JettyUtils extends Logging { securityMgr: SecurityManager): HttpServlet = { new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) { - response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) - response.setStatus(HttpServletResponse.SC_OK) - val result = servletParams.responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter.println(servletParams.extractFn(result)) - } else { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - 
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, - "User is not authorized to access this page.") + try { + if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) { + response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) + response.setStatus(HttpServletResponse.SC_OK) + val result = servletParams.responder(request) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.getWriter.println(servletParams.extractFn(result)) + } else { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + "User is not authorized to access this page.") + } + } catch { + case e: IllegalArgumentException => + response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 6f446c5a95a0a..cae6870c2ab20 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,18 +24,25 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" - val TASK_DESERIALIZATION_TIME = - """Time spent deserializating the task closure on the executor.""" + val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor." - val INPUT = "Bytes read from Hadoop or from Spark storage." + val SHUFFLE_READ_BLOCKED_TIME = + "Time that the task spent blocked waiting for shuffle data to be read from remote machines." - val OUTPUT = "Bytes written to Hadoop." + val INPUT = "Bytes and records read from Hadoop or from Spark storage." - val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." + val OUTPUT = "Bytes and records written to Hadoop." 
+ + val SHUFFLE_WRITE = + "Bytes and records written to disk in order to be read by a shuffle in a future stage." val SHUFFLE_READ = - """Bytes read from remote executors. Typically less than shuffle write bytes - because this does not include shuffle data read locally.""" + """Total shuffle bytes and records read (includes both data read locally and data read from + remote executors). """ + + val SHUFFLE_READ_REMOTE_SIZE = + """Total shuffle bytes read from remote executors. This is a subset of the shuffle + read bytes; the remaining shuffle data is read locally. """ val GETTING_RESULT_TIME = """Time that the driver spends fetching task results from workers. If this is large, consider diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index b4677447c8872..fc1844600f1cb 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -22,20 +22,23 @@ import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.scheduler.SchedulingMode +// scalastyle:off /** * Continuously generates jobs that expose various features of the WebUI (internal testing tool). 
* - * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] + * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] [#job set (4 jobs per set)] */ +// scalastyle:on private[spark] object UIWorkloadGenerator { val NUM_PARTITIONS = 100 val INTER_JOB_WAIT_MS = 5000 def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 3) { println( - "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") + "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + + "[master] [FIFO|FAIR] [#job set (4 jobs per set)]") System.exit(1) } @@ -45,6 +48,7 @@ private[spark] object UIWorkloadGenerator { if (schedulingMode == SchedulingMode.FAIR) { conf.set("spark.scheduler.mode", "FAIR") } + val nJobSet = args(2).toInt val sc = new SparkContext(conf) def setProperties(s: String) = { @@ -84,7 +88,7 @@ private[spark] object UIWorkloadGenerator { ("Job with delays", baseData.map(x => Thread.sleep(100)).count) ) - while (true) { + (1 to nJobSet).foreach { _ => for ((desc, job) <- jobs) { new Thread { override def run() { @@ -101,5 +105,6 @@ private[spark] object UIWorkloadGenerator { Thread.sleep(INTER_JOB_WAIT_MS) } } + sc.stop() } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index c82730f524eb7..f0ae95bb8c812 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -43,7 +43,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } id }.getOrElse { - return Text(s"Missing executorId parameter") + throw new IllegalArgumentException(s"Missing executorId parameter") } val time = System.currentTimeMillis() val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) diff --git 
a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 363cb96de7998..956608d7c0cbe 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -26,7 +26,8 @@ import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} import org.apache.spark.util.Utils /** Summary information about an executor to display in the UI. */ -private case class ExecutorSummaryInfo( +// Needs to be private[ui] because of a false positive MiMa failure. +private[ui] case class ExecutorSummaryInfo( id: String, hostPort: String, rddBlocks: Int, @@ -40,7 +41,8 @@ private case class ExecutorSummaryInfo( totalInputBytes: Long, totalShuffleRead: Long, totalShuffleWrite: Long, - maxMemory: Long) + maxMemory: Long, + executorLogs: Map[String, String]) private[ui] class ExecutorsPage( parent: ExecutorsTab, @@ -55,6 +57,7 @@ private[ui] class ExecutorsPage( val diskUsed = storageStatusList.map(_.diskUsed).sum val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId) val execInfoSorted = execInfo.sortBy(_.id) + val logsExist = execInfo.filter(_.executorLogs.nonEmpty).nonEmpty val execTable = @@ -79,10 +82,11 @@ private[ui] class ExecutorsPage( Shuffle Write + {if (logsExist) else Seq.empty} {if (threadDumpEnabled) else Seq.empty} - {execInfoSorted.map(execRow)} + {execInfoSorted.map(execRow(_, logsExist))}
    LogsThread Dump
    @@ -107,7 +111,7 @@ private[ui] class ExecutorsPage( } /** Render an HTML row representing an executor */ - private def execRow(info: ExecutorSummaryInfo): Seq[Node] = { + private def execRow(info: ExecutorSummaryInfo, logsExist: Boolean): Seq[Node] = { val maximumMemory = info.maxMemory val memoryUsed = info.memoryUsed val diskUsed = info.diskUsed @@ -138,6 +142,21 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(info.totalShuffleWrite)} + { + if (logsExist) { + + { + info.executorLogs.map { case (logName, logUrl) => + + } + } + + } + } { if (threadDumpEnabled) { val encodedId = URLEncoder.encode(info.id, "UTF-8") @@ -168,6 +187,7 @@ private[ui] class ExecutorsPage( val totalInputBytes = listener.executorToInputBytes.getOrElse(execId, 0L) val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0L) val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0L) + val executorLogs = listener.executorToLogUrls.getOrElse(execId, Map.empty) new ExecutorSummaryInfo( execId, @@ -183,7 +203,8 @@ private[ui] class ExecutorsPage( totalInputBytes, totalShuffleRead, totalShuffleWrite, - maxMem + maxMem, + executorLogs ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index dd1c2b78c4094..3afd7ef07d7c9 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -48,12 +48,20 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToTasksFailed = HashMap[String, Int]() val executorToDuration = HashMap[String, Long]() val executorToInputBytes = HashMap[String, Long]() + val executorToInputRecords = HashMap[String, Long]() val executorToOutputBytes = HashMap[String, Long]() + val executorToOutputRecords = HashMap[String, Long]() val executorToShuffleRead = HashMap[String, Long]() val executorToShuffleWrite = 
HashMap[String, Long]() + val executorToLogUrls = HashMap[String, Map[String, String]]() def storageStatusList = storageStatusListener.storageStatusList + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) = synchronized { + val eid = executorAdded.executorId + executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap + } + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { val eid = taskStart.taskInfo.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 @@ -78,10 +86,14 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp metrics.inputMetrics.foreach { inputMetrics => executorToInputBytes(eid) = executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead + executorToInputRecords(eid) = + executorToInputRecords.getOrElse(eid, 0L) + inputMetrics.recordsRead } metrics.outputMetrics.foreach { outputMetrics => executorToOutputBytes(eid) = executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten + executorToOutputRecords(eid) = + executorToOutputRecords.getOrElse(eid, 0L) + outputMetrics.recordsWritten } metrics.shuffleReadMetrics.foreach { shuffleRead => executorToShuffleRead(eid) = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 1d1c701878447..bd923d78a86ce 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -21,7 +21,6 @@ import scala.xml.{Node, NodeSeq} import javax.servlet.http.HttpServletRequest -import org.apache.spark.JobExecutionStatus import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData.JobUIData @@ -43,7 +42,9 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } def makeRow(job: JobUIData): Seq[Node] = { - val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max) + val 
lastStageInfo = Option(job.stageIds) + .filter(_.nonEmpty) + .flatMap { ids => listener.stageIdToInfo.get(ids.max) } val lastStageData = lastStageInfo.flatMap { s => listener.stageIdToData.get((s.stageId, s.attemptId)) } @@ -51,13 +52,13 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") val duration: Option[Long] = { - job.startTime.map { start => - val end = job.endTime.getOrElse(System.currentTimeMillis()) + job.submissionTime.map { start => + val end = job.completionTime.getOrElse(System.currentTimeMillis()) end - start } } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") - val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown") + val formattedSubmissionTime = job.submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") val detailUrl = "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId) @@ -65,10 +66,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} -
    {lastStageDescription}
    + {lastStageDescription} {lastStageName} - + {formattedSubmissionTime} {formattedDuration} @@ -101,11 +102,11 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val now = System.currentTimeMillis val activeJobsTable = - jobsTable(activeJobs.sortBy(_.startTime.getOrElse(-1L)).reverse) + jobsTable(activeJobs.sortBy(_.submissionTime.getOrElse(-1L)).reverse) val completedJobsTable = - jobsTable(completedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + jobsTable(completedJobs.sortBy(_.completionTime.getOrElse(-1L)).reverse) val failedJobsTable = - jobsTable(failedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + jobsTable(failedJobs.sortBy(_.completionTime.getOrElse(-1L)).reverse) val shouldShowActiveJobs = activeJobs.nonEmpty val shouldShowCompletedJobs = completedJobs.nonEmpty diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index b0f8ca2ab0d3f..527f960af2dfc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -33,6 +33,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { val activeStages = listener.activeStages.values.toSeq + val pendingStages = listener.pendingStages.values.toSeq val completedStages = listener.completedStages.reverse.toSeq val numCompletedStages = listener.numCompletedStages val failedStages = listener.failedStages.reverse.toSeq @@ -43,6 +44,10 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = parent.killEnabled) + val pendingStagesTable = + new StageTableBase(pendingStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, 
isFairScheduler = parent.isFairScheduler, + killEnabled = false) val completedStagesTable = new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false) @@ -54,48 +59,86 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) val poolTable = new PoolTable(pools, parent) + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = pendingStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty + val summary: NodeSeq =
      - {if (sc.isDefined) { - // Total duration is not meaningful unless the UI is live -
    • - Total Duration: - {UIUtils.formatDuration(now - sc.get.startTime)} -
    • - }} + { + if (sc.isDefined) { + // Total duration is not meaningful unless the UI is live +
    • + Total Duration: + {UIUtils.formatDuration(now - sc.get.startTime)} +
    • + } + }
    • Scheduling Mode: {listener.schedulingMode.map(_.toString).getOrElse("Unknown")}
    • -
    • - Active Stages: - {activeStages.size} -
    • -
    • - Completed Stages: - {numCompletedStages} -
    • -
    • - Failed Stages: - {numFailedStages} -
    • + { + if (shouldShowActiveStages) { +
    • + Active Stages: + {activeStages.size} +
    • + } + } + { + if (shouldShowPendingStages) { +
    • + Pending Stages: + {pendingStages.size} +
    • + } + } + { + if (shouldShowCompletedStages) { +
    • + Completed Stages: + {numCompletedStages} +
    • + } + } + { + if (shouldShowFailedStages) { +
    • + Failed Stages: + {numFailedStages} +
    • + } + }
    - val content = summary ++ - {if (sc.isDefined && isFairScheduler) { -

    {pools.size} Fair Scheduler Pools

    ++ poolTable.toNodeSeq - } else { - Seq[Node]() - }} ++ -

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq ++ -

    Completed Stages ({numCompletedStages})

    ++ - completedStagesTable.toNodeSeq ++ -

    Failed Stages ({numFailedStages})

    ++ + var content = summary ++ + { + if (sc.isDefined && isFairScheduler) { +

    {pools.size} Fair Scheduler Pools

    ++ poolTable.toNodeSeq + } else { + Seq[Node]() + } + } + if (shouldShowActiveStages) { + content ++=

    Active Stages ({activeStages.size})

    ++ + activeStagesTable.toNodeSeq + } + if (shouldShowPendingStages) { + content ++=

    Pending Stages ({pendingStages.size})

    ++ + pendingStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

    Completed Stages ({numCompletedStages})

    ++ + completedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

    Failed Stages ({numFailedStages})

    ++ failedStagesTable.toNodeSeq - + } UIUtils.headerSparkPage("Spark Stages (for all jobs)", content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 9836d11a6d85f..1f8536d1b7195 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -36,6 +36,20 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage /** Special table which merges two header cells. */ private def executorTable[T](): Seq[Node] = { + val stageData = listener.stageIdToData.get((stageId, stageAttemptId)) + var hasInput = false + var hasOutput = false + var hasShuffleWrite = false + var hasShuffleRead = false + var hasBytesSpilled = false + stageData.foreach(data => { + hasInput = data.hasInput + hasOutput = data.hasOutput + hasShuffleRead = data.hasShuffleRead + hasShuffleWrite = data.hasShuffleWrite + hasBytesSpilled = data.hasBytesSpilled + }) + @@ -44,12 +58,32 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage - - - - - - + {if (hasInput) { + + }} + {if (hasOutput) { + + }} + {if (hasShuffleRead) { + + }} + {if (hasShuffleWrite) { + + }} + {if (hasBytesSpilled) { + + + }} {createExecutorTable()} @@ -76,18 +110,34 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage - - - - - - + {if (stageData.hasInput) { + + }} + {if (stageData.hasOutput) { + + }} + {if (stageData.hasShuffleRead) { + + }} + {if (stageData.hasShuffleWrite) { + + }} + {if (stageData.hasBytesSpilled) { + + + }} } case None => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 77d36209c6048..7541d3e9c72e7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ 
-32,7 +32,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val jobId = request.getParameter("id").toInt + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val jobId = parameterId.toInt val jobDataOption = listener.jobIdToData.get(jobId) if (jobDataOption.isEmpty) { val content = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 72935beb3a34a..937d95a934b59 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -56,6 +56,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val jobIdToData = new HashMap[JobId, JobUIData] // Stages: + val pendingStages = new HashMap[StageId, StageInfo] val activeStages = new HashMap[StageId, StageInfo] val completedStages = ListBuffer[StageInfo]() val skippedStages = ListBuffer[StageInfo]() @@ -153,14 +154,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val jobData: JobUIData = new JobUIData( jobId = jobStart.jobId, - startTime = Some(System.currentTimeMillis), - endTime = None, + submissionTime = Option(jobStart.time).filter(_ >= 0), stageIds = jobStart.stageIds, jobGroup = jobGroup, status = JobExecutionStatus.RUNNING) + jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x) // Compute (a potential underestimate of) the number of tasks that will be run by this job. 
// This may be an underestimate because the job start event references all of the result - // stages's transitive stage dependencies, but some of these stages might be skipped if their + // stages' transitive stage dependencies, but some of these stages might be skipped if their // output is available from earlier runs. // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. jobData.numTasks = { @@ -186,7 +187,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) } - jobData.endTime = Some(System.currentTimeMillis()) + jobData.completionTime = Option(jobEnd.time).filter(_ >= 0) + + jobData.stageIds.foreach(pendingStages.remove) jobEnd.jobResult match { case JobSucceeded => completedJobs += jobData @@ -200,6 +203,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { for (stageId <- jobData.stageIds) { stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => jobsUsingStage.remove(jobEnd.jobId) + if (jobsUsingStage.isEmpty) { + stageIdToActiveJobIds.remove(stageId) + } stageIdToInfo.get(stageId).foreach { stageInfo => if (stageInfo.submissionTime.isEmpty) { // if this stage is pending, it won't complete, so mark it as "skipped": @@ -257,7 +263,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { val stage = stageSubmitted.stageInfo activeStages(stage.stageId) = stage - + pendingStages.remove(stage.stageId) val poolName = Option(stageSubmitted.properties).map { p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME) }.getOrElse(DEFAULT_POOL_NAME) @@ -309,7 +315,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val info = taskEnd.taskInfo // If stage 
attempt id is -1, it means the DAGScheduler had no idea which attempt this task - // compeletion event is for. Let's just drop it here. This means we might have some speculation + // completion event is for. Let's just drop it here. This means we might have some speculation // tasks on the web ui that's never marked as complete. if (info != null && taskEnd.stageAttemptId != -1) { val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), { @@ -391,24 +397,48 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.shuffleWriteBytes += shuffleWriteDelta execSummary.shuffleWrite += shuffleWriteDelta + val shuffleWriteRecordsDelta = + (taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleRecordsWritten).getOrElse(0L)) + stageData.shuffleWriteRecords += shuffleWriteRecordsDelta + execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta + val shuffleReadDelta = - (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L)) - stageData.shuffleReadBytes += shuffleReadDelta + (taskMetrics.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.totalBytesRead).getOrElse(0L)) + stageData.shuffleReadTotalBytes += shuffleReadDelta execSummary.shuffleRead += shuffleReadDelta + val shuffleReadRecordsDelta = + (taskMetrics.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.recordsRead).getOrElse(0L)) + stageData.shuffleReadRecords += shuffleReadRecordsDelta + execSummary.shuffleReadRecords += shuffleReadRecordsDelta + val inputBytesDelta = (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L) - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L)) stageData.inputBytes += inputBytesDelta execSummary.inputBytes += inputBytesDelta 
+ val inputRecordsDelta = + (taskMetrics.inputMetrics.map(_.recordsRead).getOrElse(0L) + - oldMetrics.flatMap(_.inputMetrics).map(_.recordsRead).getOrElse(0L)) + stageData.inputRecords += inputRecordsDelta + execSummary.inputRecords += inputRecordsDelta + val outputBytesDelta = (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L) - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L)) stageData.outputBytes += outputBytesDelta execSummary.outputBytes += outputBytesDelta + val outputRecordsDelta = + (taskMetrics.outputMetrics.map(_.recordsWritten).getOrElse(0L) + - oldMetrics.flatMap(_.outputMetrics).map(_.recordsWritten).getOrElse(0L)) + stageData.outputRecords += outputRecordsDelta + execSummary.outputRecords += outputRecordsDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 5fc6cc7533150..f47cdc935e539 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -32,6 +32,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { val poolName = request.getParameter("poolname") + require(poolName != null && poolName.nonEmpty, "Missing poolname parameter") + val poolToActiveStages = listener.poolToActiveStages val activeStages = poolToActiveStages.get(poolName) match { case Some(s) => s.values.toSeq diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 09a936c2234c0..d752434ad58ae 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -36,8 +36,14 @@ 
private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val stageId = request.getParameter("id").toInt - val stageAttemptId = request.getParameter("attempt").toInt + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val parameterAttempt = request.getParameter("attempt") + require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") + + val stageId = parameterId.toInt + val stageAttemptId = parameterAttempt.toInt val stageDataOption = listener.stageIdToData.get((stageId, stageAttemptId)) if (stageDataOption.isEmpty || stageDataOption.get.taskData.isEmpty) { @@ -56,11 +62,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val numCompleted = tasks.count(_.taskInfo.finished) val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables val hasAccumulators = accumulables.size > 0 - val hasInput = stageData.inputBytes > 0 - val hasOutput = stageData.outputBytes > 0 - val hasShuffleRead = stageData.shuffleReadBytes > 0 - val hasShuffleWrite = stageData.shuffleWriteBytes > 0 - val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 val summary =
    @@ -69,31 +70,33 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Total task time across all tasks: {UIUtils.formatDuration(stageData.executorRunTime)} - {if (hasInput) { + {if (stageData.hasInput) {
  • - Input: - {Utils.bytesToString(stageData.inputBytes)} + Input Size / Records: + {s"${Utils.bytesToString(stageData.inputBytes)} / ${stageData.inputRecords}"}
  • }} - {if (hasOutput) { + {if (stageData.hasOutput) {
  • Output: - {Utils.bytesToString(stageData.outputBytes)} + {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"}
  • }} - {if (hasShuffleRead) { + {if (stageData.hasShuffleRead) {
  • Shuffle read: - {Utils.bytesToString(stageData.shuffleReadBytes)} + {s"${Utils.bytesToString(stageData.shuffleReadTotalBytes)} / " + + s"${stageData.shuffleReadRecords}"}
  • }} - {if (hasShuffleWrite) { + {if (stageData.hasShuffleWrite) {
  • Shuffle write: - {Utils.bytesToString(stageData.shuffleWriteBytes)} + {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " + + s"${stageData.shuffleWriteRecords}"}
  • }} - {if (hasBytesSpilled) { + {if (stageData.hasBytesSpilled) {
  • Shuffle spill (memory): {Utils.bytesToString(stageData.memoryBytesSpilled)} @@ -132,6 +135,22 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Task Deserialization Time
  • + {if (stageData.hasShuffleRead) { +
  • + + + Shuffle Read Blocked Time + +
  • +
  • + + + Shuffle Remote Reads + +
  • + }}
  • @@ -165,20 +184,33 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput) Seq(("Input", "")) else Nil} ++ - {if (hasOutput) Seq(("Output", "")) else Nil} ++ - {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ - {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ - {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) - else Nil} ++ + {if (stageData.hasInput) Seq(("Input Size / Records", "")) else Nil} ++ + {if (stageData.hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ + {if (stageData.hasShuffleRead) { + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read Size / Records", ""), + ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + } else { + Nil + }} ++ + {if (stageData.hasShuffleWrite) { + Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + } else { + Nil + }} ++ + {if (stageData.hasBytesSpilled) { + Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + } else { + Nil + }} ++ Seq(("Errors", "")) val unzipped = taskHeadersAndCssClasses.unzip val taskTable = UIUtils.listingTable( unzipped._1, - taskRow(hasAccumulators, hasInput, hasOutput, hasShuffleRead, hasShuffleWrite, - hasBytesSpilled), + taskRow(hasAccumulators, stageData.hasInput, stageData.hasOutput, + stageData.hasShuffleRead, stageData.hasShuffleWrite, stageData.hasBytesSpilled), tasks, headerClasses = unzipped._2) // Excludes tasks which failed and have incomplete metrics @@ -189,8 +221,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { None } else { + def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = + 
Distribution(data).get.getQuantiles() + def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { - Distribution(times).get.getQuantiles().map { millis => + getDistributionQuantiles(times).map { millis =>
  • } } @@ -259,29 +294,86 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { getFormattedTimeQuantiles(schedulerDelays) def getFormattedSizeQuantiles(data: Seq[Double]) = - Distribution(data).get.getQuantiles().map(d => ) + getDistributionQuantiles(data).map(d => ) + + def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) = { + val recordDist = getDistributionQuantiles(records).iterator + getDistributionQuantiles(data).map(d => + + ) + } val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputQuantiles = +: getFormattedSizeQuantiles(inputSizes) + + val inputRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble + } + + val inputQuantiles = +: + getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } - val outputQuantiles = +: getFormattedSizeQuantiles(outputSizes) - val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => + val outputRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + } + + val outputQuantiles = +: + getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) + + val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + } + val shuffleReadBlockedQuantiles = + +: + getFormattedTimeQuantiles(shuffleReadBlockedTimes) + + val shuffleReadTotalSizes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble + } + val shuffleReadTotalRecords = validTasks.map { case TaskUIData(_, metrics, _) => + 
metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble + } + val shuffleReadTotalQuantiles = + +: + getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) + + val shuffleReadRemoteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } - val shuffleReadQuantiles = +: - getFormattedSizeQuantiles(shuffleReadSizes) + val shuffleReadRemoteQuantiles = + +: + getFormattedSizeQuantiles(shuffleReadRemoteSizes) val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble } - val shuffleWriteQuantiles = +: - getFormattedSizeQuantiles(shuffleWriteSizes) + + val shuffleWriteRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L).toDouble + } + + val shuffleWriteQuantiles = +: + getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.memoryBytesSpilled.toDouble @@ -306,12 +398,22 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {serializationQuantiles} , {gettingResultQuantiles}, - if (hasInput) {inputQuantiles} else Nil, - if (hasOutput) {outputQuantiles} else Nil, - if (hasShuffleRead) {shuffleReadQuantiles} else Nil, - if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) + if (stageData.hasInput) {inputQuantiles} else Nil, + if (stageData.hasOutput) {outputQuantiles} else Nil, + if (stageData.hasShuffleRead) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + + } else { + Nil + }, + if (stageData.hasShuffleWrite) {shuffleWriteQuantiles} else Nil, + if 
(stageData.hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, + if (stageData.hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", "Max") @@ -370,21 +472,36 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val inputReadable = maybeInput .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") .getOrElse("") + val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") val maybeOutput = metrics.flatMap(_.outputMetrics) val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") val outputReadable = maybeOutput .map(m => s"${Utils.bytesToString(m.bytesWritten)}") .getOrElse("") - - val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) - val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") - val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = - metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten) - val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("") - val shuffleWriteReadable = maybeShuffleWrite.map(Utils.bytesToString).getOrElse("") + val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") + + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) + val shuffleReadBlockedTimeSortable = maybeShuffleRead + .map(_.fetchWaitTime.toString).getOrElse("") + val shuffleReadBlockedTimeReadable = + maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") + + val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) + val shuffleReadSortable = totalShuffleBytes.map(_.toString).getOrElse("") + val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") + val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") + + val remoteShuffleBytes = 
maybeShuffleRead.map(_.remoteBytesRead) + val shuffleReadRemoteSortable = remoteShuffleBytes.map(_.toString).getOrElse("") + val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") + + val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) + val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten.toString).getOrElse("") + val shuffleWriteReadable = maybeShuffleWrite + .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") + val shuffleWriteRecords = maybeShuffleWrite + .map(_.shuffleRecordsWritten.toString).getOrElse("") val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") @@ -440,17 +557,25 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { }} {if (hasInput) { }} {if (hasOutput) { }} {if (hasShuffleRead) { + + }} {if (hasShuffleWrite) { @@ -458,7 +583,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {writeTimeReadable} }} {if (hasBytesSpilled) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index e7d6244dcd679..5865850fa09b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -112,9 +112,8 @@ private[ui] class StageTableBase( stageData <- listener.stageIdToData.get((s.stageId, s.attemptId)) desc <- stageData.description } yield { -
    {desc}
    + {desc} } -
    {stageDesc.getOrElse("")} {killLink} {nameLink} {details}
    } @@ -139,7 +138,7 @@ private[ui] class StageTableBase( val inputReadWithUnit = if (inputRead > 0) Utils.bytesToString(inputRead) else "" val outputWrite = stageData.outputBytes val outputWriteWithUnit = if (outputWrite > 0) Utils.bytesToString(outputWrite) else "" - val shuffleRead = stageData.shuffleReadBytes + val shuffleRead = stageData.shuffleReadTotalBytes val shuffleReadWithUnit = if (shuffleRead > 0) Utils.bytesToString(shuffleRead) else "" val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 2d13bb6ddde42..9bf67db8acde1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -27,6 +27,8 @@ package org.apache.spark.ui.jobs private[spark] object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val TASK_DESERIALIZATION_TIME = "deserialization_time" + val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time" + val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 48fd7caa1a1ed..dbf1ceeda1878 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -31,24 +31,28 @@ private[jobs] object UIData { var failedTasks : Int = 0 var succeededTasks : Int = 0 var inputBytes : Long = 0 + var inputRecords : Long = 0 var outputBytes : Long = 0 + var outputRecords : Long = 0 var shuffleRead : Long = 0 + var shuffleReadRecords : Long = 0 var shuffleWrite : Long = 0 + var shuffleWriteRecords : 
Long = 0 var memoryBytesSpilled : Long = 0 var diskBytesSpilled : Long = 0 } class JobUIData( var jobId: Int = -1, - var startTime: Option[Long] = None, - var endTime: Option[Long] = None, + var submissionTime: Option[Long] = None, + var completionTime: Option[Long] = None, var stageIds: Seq[Int] = Seq.empty, var jobGroup: Option[String] = None, var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN, /* Tasks */ // `numTasks` is a potential underestimate of the true number of tasks that this job will run. // This may be an underestimate because the job start event references all of the result - // stages's transitive stage dependencies, but some of these stages might be skipped if their + // stages' transitive stage dependencies, but some of these stages might be skipped if their // output is available from earlier runs. // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. var numTasks: Int = 0, @@ -73,9 +77,13 @@ private[jobs] object UIData { var executorRunTime: Long = _ var inputBytes: Long = _ + var inputRecords: Long = _ var outputBytes: Long = _ - var shuffleReadBytes: Long = _ + var outputRecords: Long = _ + var shuffleReadTotalBytes: Long = _ + var shuffleReadRecords : Long = _ var shuffleWriteBytes: Long = _ + var shuffleWriteRecords: Long = _ var memoryBytesSpilled: Long = _ var diskBytesSpilled: Long = _ @@ -85,6 +93,12 @@ private[jobs] object UIData { var accumulables = new HashMap[Long, AccumulableInfo] var taskData = new HashMap[Long, TaskUIData] var executorSummary = new HashMap[String, ExecutorSummary] + + def hasInput = inputBytes > 0 + def hasOutput = outputBytes > 0 + def hasShuffleRead = shuffleReadTotalBytes > 0 + def hasShuffleWrite = shuffleWriteBytes > 0 + def hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0 } /** diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 12d23a92878cf..199f731b92bcc 100644 
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -30,7 +30,10 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val rddId = request.getParameter("id").toInt + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val rddId = parameterId.toInt val storageStatusList = listener.storageStatusList val rddInfo = listener.rddInfoList.find(_.id == rddId).getOrElse { // Rather than crashing, render an "RDD Not Found" page diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 4c9b1e3c46f0f..48a6ede05e17b 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap import scala.concurrent.Await import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.util.Try import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -78,8 +79,6 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) - val akkaFailureDetector = - conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) val secretKey = securityManager.getSecretKey() @@ -91,8 +90,11 @@ private[spark] object AkkaUtils extends Logging { val secureCookie = if (isAuthOn) secretKey else "" logDebug(s"In createActorSystem, requireCookie is: $requireCookie") - val akkaConf = 
ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback( - ConfigFactory.parseString( + val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig + .getOrElse(ConfigFactory.empty()) + + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]) + .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString( s""" |akka.daemonic = on |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] @@ -102,7 +104,6 @@ private[spark] object AkkaUtils extends Logging { |akka.remote.secure-cookie = "$secureCookie" |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s - |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector |akka.actor.provider = "akka.remote.RemoteActorRefProvider" |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport" |akka.remote.netty.tcp.hostname = "$host" @@ -214,7 +215,7 @@ private[spark] object AkkaUtils extends Logging { val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") - val url = s"akka.tcp://$driverActorSystemName@$driverHost:$driverPort/user/$name" + val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) val timeout = AkkaUtils.lookupTimeout(conf) logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) @@ -228,9 +229,33 @@ private[spark] object AkkaUtils extends Logging { actorSystem: ActorSystem): ActorRef = { val executorActorSystemName = SparkEnv.executorActorSystemName Utils.checkHost(host, "Expected hostname") - val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name" + val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) val timeout = 
AkkaUtils.lookupTimeout(conf) logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } + + def protocol(actorSystem: ActorSystem): String = { + val akkaConf = actorSystem.settings.config + val sslProp = "akka.remote.netty.tcp.enable-ssl" + protocol(akkaConf.hasPath(sslProp) && akkaConf.getBoolean(sslProp)) + } + + def protocol(ssl: Boolean = false): String = { + if (ssl) { + "akka.ssl.tcp" + } else { + "akka.tcp" + } + } + + def address( + protocol: String, + systemName: String, + host: String, + port: Any, + actorName: String): String = { + s"$protocol://$systemName@$host:$port/user/$actorName" + } + } diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala new file mode 100644 index 0000000000000..18c627e8c7a15 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean + +import com.google.common.annotations.VisibleForTesting + +/** + * Asynchronously passes events to registered listeners. + * + * Until `start()` is called, all posted events are only buffered. Only after this listener bus + * has started will events be actually propagated to all attached listeners. This listener bus + * is stopped when `stop()` is called, and it will drop further events after stopping. + * + * @param name name of the listener bus, will be the name of the listener thread. + * @tparam L type of listener + * @tparam E type of event + */ +private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: String) + extends ListenerBus[L, E] { + + self => + + /* Cap the capacity of the event queue so we get an explicit error (rather than + * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ + private val EVENT_QUEUE_CAPACITY = 10000 + private val eventQueue = new LinkedBlockingQueue[E](EVENT_QUEUE_CAPACITY) + + // Indicate if `start()` is called + private val started = new AtomicBoolean(false) + // Indicate if `stop()` is called + private val stopped = new AtomicBoolean(false) + + // Indicate if we are processing some event + // Guarded by `self` + private var processingEvent = false + + // A counter that represents the number of events produced and consumed in the queue + private val eventLock = new Semaphore(0) + + private val listenerThread = new Thread(name) { + setDaemon(true) + override def run(): Unit = Utils.logUncaughtExceptions { + while (true) { + eventLock.acquire() + self.synchronized { + processingEvent = true + } + try { + val event = eventQueue.poll + if (event == null) { + // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. 
So `stopped` must be true") + } + return + } + postToAll(event) + } finally { + self.synchronized { + processingEvent = false + } + } + } + } + } + + /** + * Start sending events to attached listeners. + * + * This first sends out all buffered events posted before this listener bus has started, then + * listens for any additional events asynchronously while the listener bus is still running. + * This should only be called once. + */ + def start() { + if (started.compareAndSet(false, true)) { + listenerThread.start() + } else { + throw new IllegalStateException(s"$name already started!") + } + } + + def post(event: E) { + if (stopped.get) { + // Drop further events to make `listenerThread` exit ASAP + logError(s"$name has already stopped! Dropping event $event") + return + } + val eventAdded = eventQueue.offer(event) + if (eventAdded) { + eventLock.release() + } else { + onDropEvent(event) + } + } + + /** + * For testing only. Wait until there are no more events in the queue, or until the specified + * time has elapsed. Return true if the queue has emptied and false is the specified time + * elapsed before the queue emptied. + */ + @VisibleForTesting + def waitUntilEmpty(timeoutMillis: Int): Boolean = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!queueIsEmpty) { + if (System.currentTimeMillis > finishTime) { + return false + } + /* Sleep rather than using wait/notify, because this is used only for testing and + * wait/notify add overhead in the general case. */ + Thread.sleep(10) + } + true + } + + /** + * For testing only. Return whether the listener daemon thread is still alive. + */ + @VisibleForTesting + def listenerThreadIsAlive: Boolean = listenerThread.isAlive + + /** + * Return whether the event queue is empty. + * + * The use of synchronized here guarantees that all events that once belonged to this queue + * have already been processed by all attached listeners, if this returns true. 
+ */ + private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } + + /** + * Stop the listener bus. It will wait until the queued events have been processed, but drop the + * new events after stopping. + */ + def stop() { + if (!started.get()) { + throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + } + if (stopped.compareAndSet(false, true)) { + // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know + // `stop` is called. + eventLock.release() + listenerThread.join() + } else { + // Keep quiet + } + } + + /** + * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be + * notified with the dropped events. + * + * Note: `onDropEvent` can be called in any thread. + */ + def onDropEvent(event: E): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/Clock.scala b/core/src/main/scala/org/apache/spark/util/Clock.scala index 97c2b45aabf28..e92ed11bd165b 100644 --- a/core/src/main/scala/org/apache/spark/util/Clock.scala +++ b/core/src/main/scala/org/apache/spark/util/Clock.scala @@ -21,9 +21,47 @@ package org.apache.spark.util * An interface to represent clocks, so that they can be mocked out in unit tests. */ private[spark] trait Clock { - def getTime(): Long + def getTimeMillis(): Long + def waitTillTime(targetTime: Long): Long } -private[spark] object SystemClock extends Clock { - def getTime(): Long = System.currentTimeMillis() +/** + * A clock backed by the actual time from the OS as reported by the `System` API. 
+ */ +private[spark] class SystemClock extends Clock { + + val minPollTime = 25L + + /** + * @return the same time (milliseconds since the epoch) + * as is reported by `System.currentTimeMillis()` + */ + def getTimeMillis(): Long = System.currentTimeMillis() + + /** + * @param targetTime block until the current time is at least this value + * @return current system time when wait has completed + */ + def waitTillTime(targetTime: Long): Long = { + var currentTime = 0L + currentTime = System.currentTimeMillis() + + var waitTime = targetTime - currentTime + if (waitTime <= 0) { + return currentTime + } + + val pollTime = math.max(waitTime / 10.0, minPollTime).toLong + + while (true) { + currentTime = System.currentTimeMillis() + waitTime = targetTime - currentTime + if (waitTime <= 0) { + return currentTime + } + val sleepTime = math.min(waitTime, pollTime) + Thread.sleep(sleepTime) + } + -1 + } } diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala new file mode 100644 index 0000000000000..b0ed908b84424 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque} + +import scala.util.control.NonFatal + +import org.apache.spark.Logging + +/** + * An event loop to receive events from the caller and process all events in the event thread. It + * will start an exclusive event thread to process all events. + * + * Note: The event queue will grow indefinitely. So subclasses should make sure `onReceive` can + * handle events in time to avoid the potential OOM. + */ +private[spark] abstract class EventLoop[E](name: String) extends Logging { + + private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]() + + private val stopped = new AtomicBoolean(false) + + private val eventThread = new Thread(name) { + setDaemon(true) + + override def run(): Unit = { + try { + while (!stopped.get) { + val event = eventQueue.take() + try { + onReceive(event) + } catch { + case NonFatal(e) => { + try { + onError(e) + } catch { + case NonFatal(e) => logError("Unexpected error in " + name, e) + } + } + } + } + } catch { + case ie: InterruptedException => // exit even if eventQueue is not empty + case NonFatal(e) => logError("Unexpected error in " + name, e) + } + } + + } + + def start(): Unit = { + if (stopped.get) { + throw new IllegalStateException(name + " has already been stopped") + } + // Call onStart before starting the event thread to make sure it happens before onReceive + onStart() + eventThread.start() + } + + def stop(): Unit = { + if (stopped.compareAndSet(false, true)) { + eventThread.interrupt() + eventThread.join() + // Call onStop after the event thread exits to make sure onReceive happens before onStop + onStop() + } else { + // Keep quiet to allow calling `stop` multiple times. + } + } + + /** + * Put the event into the event queue. The event thread will process it later. 
+ */ + def post(event: E): Unit = { + eventQueue.put(event) + } + + /** + * Return if the event thread has already been started but not yet stopped. + */ + def isActive: Boolean = eventThread.isAlive + + /** + * Invoked when `start()` is called but before the event thread starts. + */ + protected def onStart(): Unit = {} + + /** + * Invoked when `stop()` is called and the event thread exits. + */ + protected def onStop(): Unit = {} + + /** + * Invoked in the event thread when polling events from the event queue. + * + * Note: Should avoid calling blocking actions in `onReceive`, or the event thread will be blocked + * and cannot process events in time. If you want to call some blocking actions, run them in + * another thread. + */ + protected def onReceive(event: E): Unit + + /** + * Invoked if `onReceive` throws any non fatal error. Any non fatal error thrown from `onError` + * will be ignored. + */ + protected def onError(e: Throwable): Unit + +} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ee3756c226fe3..8e20864db5673 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -90,7 +90,6 @@ private[spark] object JsonProtocol { case executorRemoved: SparkListenerExecutorRemoved => executorRemovedToJson(executorRemoved) // These aren't used, but keeps compiler happy - case SparkListenerShutdown => JNothing case SparkListenerExecutorMetricsUpdate(_, _) => JNothing } } @@ -141,6 +140,7 @@ private[spark] object JsonProtocol { val properties = propertiesToJson(jobStart.properties) ("Event" -> Utils.getFormattedClassName(jobStart)) ~ ("Job ID" -> jobStart.jobId) ~ + ("Submission Time" -> jobStart.time) ~ ("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0 ("Stage IDs" -> jobStart.stageIds) ~ ("Properties" -> properties) @@ -150,6 +150,7 @@ private[spark] 
object JsonProtocol { val jobResult = jobResultToJson(jobEnd.jobResult) ("Event" -> Utils.getFormattedClassName(jobEnd)) ~ ("Job ID" -> jobEnd.jobId) ~ + ("Completion Time" -> jobEnd.time) ~ ("Job Result" -> jobResult) } @@ -201,13 +202,16 @@ private[spark] object JsonProtocol { def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { ("Event" -> Utils.getFormattedClassName(executorAdded)) ~ + ("Timestamp" -> executorAdded.time) ~ ("Executor ID" -> executorAdded.executorId) ~ ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo)) } def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = { ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~ - ("Executor ID" -> executorRemoved.executorId) + ("Timestamp" -> executorRemoved.time) ~ + ("Executor ID" -> executorRemoved.executorId) ~ + ("Removed Reason" -> executorRemoved.reason) } /** ------------------------------------------------------------------- * @@ -289,22 +293,27 @@ private[spark] object JsonProtocol { ("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~ ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~ ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~ - ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) + ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) ~ + ("Local Bytes Read" -> shuffleReadMetrics.localBytesRead) ~ + ("Total Records Read" -> shuffleReadMetrics.recordsRead) } def shuffleWriteMetricsToJson(shuffleWriteMetrics: ShuffleWriteMetrics): JValue = { ("Shuffle Bytes Written" -> shuffleWriteMetrics.shuffleBytesWritten) ~ - ("Shuffle Write Time" -> shuffleWriteMetrics.shuffleWriteTime) + ("Shuffle Write Time" -> shuffleWriteMetrics.shuffleWriteTime) ~ + ("Shuffle Records Written" -> shuffleWriteMetrics.shuffleRecordsWritten) } def inputMetricsToJson(inputMetrics: InputMetrics): JValue = { ("Data Read Method" -> inputMetrics.readMethod.toString) ~ - ("Bytes Read" -> 
inputMetrics.bytesRead) + ("Bytes Read" -> inputMetrics.bytesRead) ~ + ("Records Read" -> inputMetrics.recordsRead) } def outputMetricsToJson(outputMetrics: OutputMetrics): JValue = { ("Data Write Method" -> outputMetrics.writeMethod.toString) ~ - ("Bytes Written" -> outputMetrics.bytesWritten) + ("Bytes Written" -> outputMetrics.bytesWritten) ~ + ("Records Written" -> outputMetrics.recordsWritten) } def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { @@ -379,7 +388,8 @@ private[spark] object JsonProtocol { def executorInfoToJson(executorInfo: ExecutorInfo): JValue = { ("Host" -> executorInfo.executorHost) ~ - ("Total Cores" -> executorInfo.totalCores) + ("Total Cores" -> executorInfo.totalCores) ~ + ("Log Urls" -> mapToJson(executorInfo.logUrlMap)) } /** ------------------------------ * @@ -492,6 +502,8 @@ private[spark] object JsonProtocol { def jobStartFromJson(json: JValue): SparkListenerJobStart = { val jobId = (json \ "Job ID").extract[Int] + val submissionTime = + Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") // The "Stage Infos" field was added in Spark 1.2.0 @@ -499,13 +511,15 @@ private[spark] object JsonProtocol { .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) } - SparkListenerJobStart(jobId, stageInfos, properties) + SparkListenerJobStart(jobId, submissionTime, stageInfos, properties) } def jobEndFromJson(json: JValue): SparkListenerJobEnd = { val jobId = (json \ "Job ID").extract[Int] + val completionTime = + Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) val jobResult = jobResultFromJson(json \ "Job Result") - SparkListenerJobEnd(jobId, jobResult) + SparkListenerJobEnd(jobId, completionTime, jobResult) } def environmentUpdateFromJson(json: JValue): 
SparkListenerEnvironmentUpdate = { @@ -547,14 +561,17 @@ private[spark] object JsonProtocol { } def executorAddedFromJson(json: JValue): SparkListenerExecutorAdded = { + val time = (json \ "Timestamp").extract[Long] val executorId = (json \ "Executor ID").extract[String] val executorInfo = executorInfoFromJson(json \ "Executor Info") - SparkListenerExecutorAdded(executorId, executorInfo) + SparkListenerExecutorAdded(time, executorId, executorInfo) } def executorRemovedFromJson(json: JValue): SparkListenerExecutorRemoved = { + val time = (json \ "Timestamp").extract[Long] val executorId = (json \ "Executor ID").extract[String] - SparkListenerExecutorRemoved(executorId) + val reason = (json \ "Removed Reason").extract[String] + SparkListenerExecutorRemoved(time, executorId, reason) } /** --------------------------------------------------------------------- * @@ -625,14 +642,14 @@ private[spark] object JsonProtocol { return TaskMetrics.empty } val metrics = new TaskMetrics - metrics.hostname = (json \ "Host Name").extract[String] - metrics.executorDeserializeTime = (json \ "Executor Deserialize Time").extract[Long] - metrics.executorRunTime = (json \ "Executor Run Time").extract[Long] - metrics.resultSize = (json \ "Result Size").extract[Long] - metrics.jvmGCTime = (json \ "JVM GC Time").extract[Long] - metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long] - metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long] - metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long] + metrics.setHostname((json \ "Host Name").extract[String]) + metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long]) + metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long]) + metrics.setResultSize((json \ "Result Size").extract[Long]) + metrics.setJvmGCTime((json \ "JVM GC Time").extract[Long]) + metrics.setResultSerializationTime((json \ "Result Serialization Time").extract[Long]) + 
metrics.incMemoryBytesSpilled((json \ "Memory Bytes Spilled").extract[Long]) + metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long]) metrics.setShuffleReadMetrics( Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)) metrics.shuffleWriteMetrics = @@ -654,31 +671,37 @@ private[spark] object JsonProtocol { def shuffleReadMetricsFromJson(json: JValue): ShuffleReadMetrics = { val metrics = new ShuffleReadMetrics - metrics.remoteBlocksFetched = (json \ "Remote Blocks Fetched").extract[Int] - metrics.localBlocksFetched = (json \ "Local Blocks Fetched").extract[Int] - metrics.fetchWaitTime = (json \ "Fetch Wait Time").extract[Long] - metrics.remoteBytesRead = (json \ "Remote Bytes Read").extract[Long] + metrics.incRemoteBlocksFetched((json \ "Remote Blocks Fetched").extract[Int]) + metrics.incLocalBlocksFetched((json \ "Local Blocks Fetched").extract[Int]) + metrics.incFetchWaitTime((json \ "Fetch Wait Time").extract[Long]) + metrics.incRemoteBytesRead((json \ "Remote Bytes Read").extract[Long]) + metrics.incLocalBytesRead((json \ "Local Bytes Read").extractOpt[Long].getOrElse(0)) + metrics.incRecordsRead((json \ "Total Records Read").extractOpt[Long].getOrElse(0)) metrics } def shuffleWriteMetricsFromJson(json: JValue): ShuffleWriteMetrics = { val metrics = new ShuffleWriteMetrics - metrics.shuffleBytesWritten = (json \ "Shuffle Bytes Written").extract[Long] - metrics.shuffleWriteTime = (json \ "Shuffle Write Time").extract[Long] + metrics.incShuffleBytesWritten((json \ "Shuffle Bytes Written").extract[Long]) + metrics.incShuffleWriteTime((json \ "Shuffle Write Time").extract[Long]) + metrics.setShuffleRecordsWritten((json \ "Shuffle Records Written") + .extractOpt[Long].getOrElse(0)) metrics } def inputMetricsFromJson(json: JValue): InputMetrics = { val metrics = new InputMetrics( DataReadMethod.withName((json \ "Data Read Method").extract[String])) - metrics.addBytesRead((json \ "Bytes Read").extract[Long]) + 
metrics.incBytesRead((json \ "Bytes Read").extract[Long]) + metrics.incRecordsRead((json \ "Records Read").extractOpt[Long].getOrElse(0)) metrics } def outputMetricsFromJson(json: JValue): OutputMetrics = { val metrics = new OutputMetrics( DataWriteMethod.withName((json \ "Data Write Method").extract[String])) - metrics.bytesWritten = (json \ "Bytes Written").extract[Long] + metrics.setBytesWritten((json \ "Bytes Written").extract[Long]) + metrics.setRecordsWritten((json \ "Records Written").extractOpt[Long].getOrElse(0)) metrics } @@ -781,7 +804,8 @@ private[spark] object JsonProtocol { def executorInfoFromJson(json: JValue): ExecutorInfo = { val executorHost = (json \ "Host").extract[String] val totalCores = (json \ "Total Cores").extract[Int] - new ExecutorInfo(executorHost, totalCores) + val logUrls = mapFromJson(json \ "Log Urls").toMap + new ExecutorInfo(executorHost, totalCores, logUrls) } /** -------------------------------- * diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala new file mode 100644 index 0000000000000..d60b8b9a31a9b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.CopyOnWriteArrayList + +import scala.util.control.NonFatal + +import org.apache.spark.Logging + +/** + * An event bus which posts events to its listeners. + */ +private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { + + // Marked `private[spark]` for access in tests. + private[spark] val listeners = new CopyOnWriteArrayList[L] + + /** + * Add a listener to listen events. This method is thread-safe and can be called in any thread. + */ + final def addListener(listener: L) { + listeners.add(listener) + } + + /** + * Post the event to all registered listeners. The `postToAll` caller should guarantee calling + * `postToAll` in the same thread for all events. + */ + final def postToAll(event: E): Unit = { + // JavaConversions will create a JIterableWrapper if we use some Scala collection functions. + // However, this method will be called frequently. To avoid the wrapper cost, here ewe use + // Java Iterator directly. + val iter = listeners.iterator + while (iter.hasNext) { + val listener = iter.next() + try { + onPostEvent(listener, event) + } catch { + case NonFatal(e) => + logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) + } + } + } + + /** + * Post an event to the specified listener. `onPostEvent` is guaranteed to be called in the same + * thread. + */ + def onPostEvent(listener: L, event: E): Unit + +} diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala new file mode 100644 index 0000000000000..cf89c1782fd67 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * A `Clock` whose time can be manually set and modified. Its reported time does not change + * as time elapses, but only as its time is modified by callers. This is mainly useful for + * testing. + * + * @param time initial time (in milliseconds since the epoch) + */ +private[spark] class ManualClock(private var time: Long) extends Clock { + + /** + * @return `ManualClock` with initial time 0 + */ + def this() = this(0L) + + def getTimeMillis(): Long = + synchronized { + time + } + + /** + * @param timeToSet new time (in milliseconds) that the clock should represent + */ + def setTime(timeToSet: Long) = + synchronized { + time = timeToSet + notifyAll() + } + + /** + * @param timeToAdd time (in milliseconds) to add to the clock's time + */ + def advance(timeToAdd: Long) = + synchronized { + time += timeToAdd + notifyAll() + } + + /** + * @param targetTime block until the clock time is set or advanced to at least this time + * @return current time reported by the clock when waiting finishes + */ + def waitTillTime(targetTime: Long): Long = + synchronized { + while (time < targetTime) { + wait(100) + } + getTimeMillis() + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala 
b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala new file mode 100644 index 0000000000000..d9c7103b2f3bf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.net.{URLClassLoader, URL} +import java.util.Enumeration +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConversions._ + +import org.apache.spark.util.ParentClassLoader + +/** + * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. + */ +private[spark] class MutableURLClassLoader(urls: Array[URL], parent: ClassLoader) + extends URLClassLoader(urls, parent) { + + override def addURL(url: URL): Unit = { + super.addURL(url) + } + + override def getURLs(): Array[URL] = { + super.getURLs() + } + +} + +/** + * A mutable class loader that gives preference to its own URLs over the parent class loader + * when loading classes and resources. 
+ */ +private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoader) + extends MutableURLClassLoader(urls, null) { + + private val parentClassLoader = new ParentClassLoader(parent) + + /** + * Used to implement fine-grained class loading locks similar to what is done by Java 7. This + * prevents deadlock issues when using non-hierarchical class loaders. + * + * Note that due to Java 6 compatibility (and some issues with implementing class loaders in + * Scala), Java 7's `ClassLoader.registerAsParallelCapable` method is not called. + */ + private val locks = new ConcurrentHashMap[String, Object]() + + override def loadClass(name: String, resolve: Boolean): Class[_] = { + var lock = locks.get(name) + if (lock == null) { + val newLock = new Object() + lock = locks.putIfAbsent(name, newLock) + if (lock == null) { + lock = newLock + } + } + + lock.synchronized { + try { + super.loadClass(name, resolve) + } catch { + case e: ClassNotFoundException => + parentClassLoader.loadClass(name, resolve) + } + } + } + + override def getResource(name: String): URL = { + val url = super.findResource(name) + val res = if (url != null) url else parentClassLoader.getResource(name) + res + } + + override def getResources(name: String): Enumeration[URL] = { + val urls = super.findResources(name) + val res = + if (urls != null && urls.hasMoreElements()) { + urls + } else { + parentClassLoader.getResources(name) + } + res + } + + override def addURL(url: URL) { + super.addURL(url) + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index 3abc12681fe9a..6d8d9e8da3678 100644 --- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -18,7 +18,7 @@ package org.apache.spark.util /** - * A class loader which makes findClass accesible to the child + * A class loader which makes some 
protected methods in ClassLoader accesible. */ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) { @@ -29,4 +29,9 @@ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader( override def loadClass(name: String): Class[_] = { super.loadClass(name) } + + override def loadClass(name: String, resolve: Boolean): Class[_] = { + super.loadClass(name, resolve) + } + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2c04e4ddfbcb7..4644088f19f4b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,8 +21,9 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import java.util.{Locale, Properties, Random, UUID} +import java.util.{Properties, Locale, Random, UUID} +import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} +import javax.net.ssl.HttpsURLConnection import scala.collection.JavaConversions._ import scala.collection.Map @@ -37,6 +38,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} +import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ @@ -60,6 +62,9 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + @volatile private var localRootDirs: Array[String] = null + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() 
@@ -209,8 +214,8 @@ private[spark] object Utils extends Logging { // Is the path already registered to be deleted via a shutdown hook ? def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { val absolutePath = file.getPath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) } } @@ -246,13 +251,28 @@ private[spark] object Utils extends Logging { retval } + /** + * JDK equivalent of `chmod 700 file`. + * + * @param file the file whose permissions will be modified + * @return true if the permissions were successfully changed, false otherwise. + */ + def chmod700(file: File): Boolean = { + file.setReadable(false, false) && + file.setReadable(true, true) && + file.setWritable(false, false) && + file.setWritable(true, true) && + file.setExecutable(false, false) && + file.setExecutable(true, true) + } + /** * Create a directory inside the given parent directory. The directory is guaranteed to be * newly created, and is not marked for automatic deletion. */ - def createDirectory(root: String): File = { + def createDirectory(root: String, namePrefix: String = "spark"): File = { var attempts = 0 - val maxAttempts = 10 + val maxAttempts = MAX_DIR_CREATION_ATTEMPTS var dir: File = null while (dir == null) { attempts += 1 @@ -261,7 +281,7 @@ private[spark] object Utils extends Logging { maxAttempts + " attempts!") } try { - dir = new File(root, "spark-" + UUID.randomUUID.toString) + dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString) if (dir.exists() || !dir.mkdirs()) { dir = null } @@ -275,8 +295,10 @@ private[spark] object Utils extends Logging { * Create a temporary directory inside the given parent directory. The directory will be * automatically deleted when the VM shuts down. 
*/ - def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { - val dir = createDirectory(root) + def createTempDir( + root: String = System.getProperty("java.io.tmpdir"), + namePrefix: String = "spark"): File = { + val dir = createDirectory(root, namePrefix) registerShutdownDeleteDir(dir) dir } @@ -359,8 +381,10 @@ private[spark] object Utils extends Logging { } /** - * Download a file to target directory. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * Download a file or directory to target directory. Supports fetching the file in a variety of + * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based + * on the URL parameter. Fetching directories is only supported from Hadoop-compatible + * filesystems. * * If `useCache` is true, first attempts to fetch the file to a local cache that's shared * across executors running the same application. 
`useCache` is used mainly for @@ -410,13 +434,19 @@ private[spark] object Utils extends Logging { // Decompress the file if it's a .tar or .tar.gz if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) { logInfo("Untarring " + fileName) - Utils.execute(Seq("tar", "-xzf", fileName), targetDir) + executeAndGetOutput(Seq("tar", "-xzf", fileName), targetDir) } else if (fileName.endsWith(".tar")) { logInfo("Untarring " + fileName) - Utils.execute(Seq("tar", "-xf", fileName), targetDir) + executeAndGetOutput(Seq("tar", "-xf", fileName), targetDir) } // Make the file executable - That's necessary for scripts FileUtil.chmod(targetFile.getAbsolutePath, "a+x") + + // Windows does not grant read permission by default to non-admin users + // Add read permission to owner explicitly + if (isWindows) { + FileUtil.chmod(targetFile.getAbsolutePath, "u+r") + } } /** @@ -429,7 +459,6 @@ private[spark] object Utils extends Logging { * * @param url URL that `sourceFile` originated from, for logging purposes. * @param in InputStream to download. - * @param tempFile File path to download `in` to. * @param destFile File path to move `tempFile` to. 
* @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match * `sourceFile` @@ -437,9 +466,11 @@ private[spark] object Utils extends Logging { private def downloadFile( url: String, in: InputStream, - tempFile: File, destFile: File, fileOverwrite: Boolean): Unit = { + val tempFile = File.createTempFile("fetchFileTemp", null, + new File(destFile.getParentFile.getAbsolutePath)) + logInfo(s"Fetching $url to $tempFile") try { val out = new FileOutputStream(tempFile) @@ -478,7 +509,7 @@ private[spark] object Utils extends Logging { removeSourceFile: Boolean = false): Unit = { if (destFile.exists) { - if (!Files.equal(sourceFile, destFile)) { + if (!filesEqualRecursive(sourceFile, destFile)) { if (fileOverwrite) { logInfo( s"File $destFile exists and does not match contents of $url, replacing it with $url" @@ -513,13 +544,44 @@ private[spark] object Utils extends Logging { Files.move(sourceFile, destFile) } else { logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}") - Files.copy(sourceFile, destFile) + copyRecursive(sourceFile, destFile) + } + } + + private def filesEqualRecursive(file1: File, file2: File): Boolean = { + if (file1.isDirectory && file2.isDirectory) { + val subfiles1 = file1.listFiles() + val subfiles2 = file2.listFiles() + if (subfiles1.size != subfiles2.size) { + return false + } + subfiles1.sortBy(_.getName).zip(subfiles2.sortBy(_.getName)).forall { + case (f1, f2) => filesEqualRecursive(f1, f2) + } + } else if (file1.isFile && file2.isFile) { + Files.equal(file1, file2) + } else { + false + } + } + + private def copyRecursive(source: File, dest: File): Unit = { + if (source.isDirectory) { + if (!dest.mkdir()) { + throw new IOException(s"Failed to create directory ${dest.getPath}") + } + val subfiles = source.listFiles() + subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName))) + } else { + Files.copy(source, dest) } } /** - * Download a file to target directory. 
Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * Download a file or directory to target directory. Supports fetching the file in a variety of + * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based + * on the URL parameter. Fetching directories is only supported from Hadoop-compatible + * filesystems. * * Throws SparkException if the target file already exists and has different contents than * the requested file. @@ -531,14 +593,11 @@ private[spark] object Utils extends Logging { conf: SparkConf, securityMgr: SecurityManager, hadoopConf: Configuration) { - val tempFile = File.createTempFile("fetchFileTemp", null, new File(targetDir.getAbsolutePath)) val targetFile = new File(targetDir, filename) val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + tempFile) - var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { logDebug("fetchFile with security enabled") @@ -549,23 +608,51 @@ private[spark] object Utils extends Logging { logDebug("fetchFile not using security") uc = new URL(url).openConnection() } + Utils.setupSecureURLConnection(uc, securityMgr) val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000 uc.setConnectTimeout(timeout) uc.setReadTimeout(timeout) uc.connect() val in = uc.getInputStream() - downloadFile(url, in, tempFile, targetFile, fileOverwrite) + downloadFile(url, in, targetFile, fileOverwrite) case "file" => // In the case of a local file, copy the local file to the target directory. // Note the difference between uri vs url. 
val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) copyFile(url, sourceFile, targetFile, fileOverwrite) case _ => - // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others val fs = getHadoopFileSystem(uri, hadoopConf) - val in = fs.open(new Path(uri)) - downloadFile(url, in, tempFile, targetFile, fileOverwrite) + val path = new Path(uri) + fetchHcfsFile(path, new File(targetDir, path.getName), fs, conf, hadoopConf, fileOverwrite) + } + } + + /** + * Fetch a file or directory from a Hadoop-compatible filesystem. + * + * Visible for testing + */ + private[spark] def fetchHcfsFile( + path: Path, + targetDir: File, + fs: FileSystem, + conf: SparkConf, + hadoopConf: Configuration, + fileOverwrite: Boolean): Unit = { + if (!targetDir.mkdir()) { + throw new IOException(s"Failed to create directory ${targetDir.getPath}") + } + fs.listStatus(path).foreach { fileStatus => + val innerPath = fileStatus.getPath + if (fileStatus.isDir) { + fetchHcfsFile(innerPath, new File(targetDir, innerPath.getName), fs, conf, hadoopConf, + fileOverwrite) + } else { + val in = fs.open(innerPath) + val targetFile = new File(targetDir, innerPath.getName) + downloadFile(innerPath.toString, in, targetFile, fileOverwrite) + } } } @@ -597,28 +684,56 @@ private[spark] object Utils extends Logging { * and returns only the directories that exist / could be created. * * If no directories could be created, this will return an empty list. + * + * This method will cache the local directories for the application when it's first invoked. + * So calling it multiple times with a different configuration will always return the same + * set of directories. 
*/ private[spark] def getOrCreateLocalRootDirs(conf: SparkConf): Array[String] = { - val confValue = if (isRunningInYarnContainer(conf)) { + if (localRootDirs == null) { + this.synchronized { + if (localRootDirs == null) { + localRootDirs = getOrCreateLocalRootDirsImpl(conf) + } + } + } + localRootDirs + } + + private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it - // to what Yarn on this system said was available. - getYarnLocalDirs(conf) + // to what Yarn on this system said was available. Note this assumes that Yarn has + // created the directories already, and that they are secured so that only the + // user has access to them. + getYarnLocalDirs(conf).split(",") + } else if (conf.getenv("SPARK_EXECUTOR_DIRS") != null) { + conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else { - Option(conf.getenv("SPARK_LOCAL_DIRS")).getOrElse( - conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) - } - val rootDirs = confValue.split(',') - logDebug(s"Getting/creating local root dirs at '$confValue'") - - rootDirs.flatMap { rootDir => - val localDir: File = new File(rootDir) - val foundLocalDir = localDir.exists || localDir.mkdirs() - if (!foundLocalDir) { - logError(s"Failed to create local root dir in $rootDir. Ignoring this directory.") - None - } else { - Some(rootDir) - } + // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user + // configuration to point to a secure directory. So create a subdirectory with restricted + // permissions under each listed directory. 
+ Option(conf.getenv("SPARK_LOCAL_DIRS")) + .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + .split(",") + .flatMap { root => + try { + val rootDir = new File(root) + if (rootDir.exists || rootDir.mkdirs()) { + val dir = createTempDir(root) + chmod700(dir) + Some(dir.getAbsolutePath) + } else { + logError(s"Failed to create dir in $root. Ignoring this directory.") + None + } + } catch { + case e: IOException => + logError(s"Failed to create local root dir in $root. Ignoring this directory.") + None + } + } + .toArray } } @@ -637,6 +752,11 @@ private[spark] object Utils extends Logging { localDirs } + /** Used by unit tests. Do not call from other places. */ + private[spark] def clearLocalRootDirs(): Unit = { + localRootDirs = null + } + /** * Shuffle the elements of a collection into a random order, returning the * result in a new collection. Unlike scala.util.Random.shuffle, this method @@ -956,25 +1076,25 @@ private[spark] object Utils extends Logging { } /** - * Execute a command in the given working directory, throwing an exception if it completes - * with an exit code other than 0. + * Execute a command and return the process running the command. 
*/ - def execute(command: Seq[String], workingDir: File) { - val process = new ProcessBuilder(command: _*) - .directory(workingDir) - .redirectErrorStream(true) - .start() - new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines()) { - System.err.println(line) - } - } - }.start() - val exitCode = process.waitFor() - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) + def executeCommand( + command: Seq[String], + workingDir: File = new File("."), + extraEnvironment: Map[String, String] = Map.empty, + redirectStderr: Boolean = true): Process = { + val builder = new ProcessBuilder(command: _*).directory(workingDir) + val environment = builder.environment() + for ((key, value) <- extraEnvironment) { + environment.put(key, value) + } + val process = builder.start() + if (redirectStderr) { + val threadName = "redirect stderr for command " + command(0) + def log(s: String): Unit = logInfo(s) + processStreamByLine(threadName, process.getErrorStream, log) } + process } /** @@ -983,31 +1103,13 @@ private[spark] object Utils extends Logging { def executeAndGetOutput( command: Seq[String], workingDir: File = new File("."), - extraEnvironment: Map[String, String] = Map.empty): String = { - val builder = new ProcessBuilder(command: _*) - .directory(workingDir) - val environment = builder.environment() - for ((key, value) <- extraEnvironment) { - environment.put(key, value) - } - - val process = builder.start() - new Thread("read stderr for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getErrorStream).getLines()) { - logInfo(line) - } - } - }.start() + extraEnvironment: Map[String, String] = Map.empty, + redirectStderr: Boolean = true): String = { + val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr) val output = new StringBuffer - val stdoutThread = new Thread("read 
stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines()) { - output.append(line) - } - } - } - stdoutThread.start() + val threadName = "read stdout for " + command(0) + def appendToOutput(s: String): Unit = output.append(s) + val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput) val exitCode = process.waitFor() stdoutThread.join() // Wait for it to finish reading output if (exitCode != 0) { @@ -1017,6 +1119,25 @@ private[spark] object Utils extends Logging { output.toString } + /** + * Return and start a daemon thread that processes the content of the input stream line by line. + */ + def processStreamByLine( + threadName: String, + inputStream: InputStream, + processLine: String => Unit): Thread = { + val t = new Thread(threadName) { + override def run() { + for (line <- Source.fromInputStream(inputStream).getLines()) { + processLine(line) + } + } + } + t.setDaemon(true) + t.start() + t + } + /** * Execute a block of code that evaluates to Unit, forwarding any uncaught exceptions to the * default UncaughtExceptionHandler @@ -1066,9 +1187,9 @@ private[spark] object Utils extends Logging { // finding the call site of a method. val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r - val SCALA_CLASS_REGEX = """^scala""".r + val SCALA_CORE_CLASS_PREFIX = "scala" val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined - val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined + val isScalaClass = className.startsWith(SCALA_CORE_CLASS_PREFIX) // If the class is a Spark internal class or a Scala class, then exclude. isSparkCoreClass || isScalaClass } @@ -1311,9 +1432,14 @@ private[spark] object Utils extends Logging { hashAbs } - /** Returns a copy of the system properties that is thread-safe to iterator over. 
*/ - def getSystemProperties(): Map[String, String] = { - System.getProperties.clone().asInstanceOf[java.util.Properties].toMap[String, String] + /** Returns the system properties map that is thread-safe to iterator over. It gets the + * properties which have been set explicitly, as well as those for which only a default value + * has been defined. */ + def getSystemProperties: Map[String, String] = { + val sysProps = for (key <- System.getProperties.stringPropertyNames()) yield + (key, System.getProperty(key)) + + sysProps.toMap } /** @@ -1779,6 +1905,20 @@ private[spark] object Utils extends Logging { PropertyConfigurator.configure(pro) } + /** + * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and + * the host verifier from the given security manager. + */ + def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = { + urlConnection match { + case https: HttpsURLConnection => + sm.sslSocketFactory.foreach(https.setSSLSocketFactory) + sm.hostnameVerifier.foreach(https.setHostnameVerifier) + https + case connection => connection + } + } + def invoke( clazz: Class[_], obj: AnyRef, @@ -1871,6 +2011,16 @@ private[spark] object Utils extends Logging { throw new SparkException("Invalid master URL: " + sparkUrl, e) } } + + /** + * Returns the current user name. This is the currently logged in user, unless that's been + * overridden by the `SPARK_USER` environment variable. 
+ */ + def getCurrentUserName(): String = { + Option(System.getenv("SPARK_USER")) + .getOrElse(UserGroupInformation.getCurrentUser().getUserName()) + } + } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8a0f5a602de12..fc7e86e297540 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -387,6 +387,15 @@ class ExternalAppendOnlyMap[K, V, C]( private var batchIndex = 0 // Which batch we're in private var fileStream: FileInputStream = null + @volatile private var closed = false + + // A volatile variable to remember which DeserializationStream is using. Need to set it when we + // open a DeserializationStream. But we should use `deserializeStream` rather than + // `deserializeStreamToBeClosed` to read the content because touching a volatile variable will + // reduce the performance. It must be volatile so that we can see its correct value in the + // `finalize` method, which could run in any thread. + @volatile private var deserializeStreamToBeClosed: DeserializationStream = null + // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams private var deserializeStream = nextBatchStream() @@ -401,6 +410,7 @@ class ExternalAppendOnlyMap[K, V, C]( // we're still in a valid batch. 
if (batchIndex < batchOffsets.length - 1) { if (deserializeStream != null) { + deserializeStreamToBeClosed = null deserializeStream.close() fileStream.close() deserializeStream = null @@ -419,7 +429,11 @@ class ExternalAppendOnlyMap[K, V, C]( val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) - ser.deserializeStream(compressedStream) + // Before returning the stream, assign it to `deserializeStreamToBeClosed` so that we can + // close it in `finalize` and also avoid to touch the volatile `deserializeStreamToBeClosed` + // during reading the (K, C) pairs. + deserializeStreamToBeClosed = ser.deserializeStream(compressedStream) + deserializeStreamToBeClosed } else { // No more batches left cleanup() @@ -468,14 +482,34 @@ class ExternalAppendOnlyMap[K, V, C]( item } - // TODO: Ensure this gets called even if the iterator isn't drained. - private def cleanup() { - batchIndex = batchOffsets.length // Prevent reading any other batch - val ds = deserializeStream - deserializeStream = null - fileStream = null - ds.close() - file.delete() + // TODO: Now only use `finalize` to ensure `close` gets called to clean up the resources. In the + // future, we need some mechanism to ensure this gets called once the resources are not used. 
+ private def cleanup(): Unit = { + if (!closed) { + closed = true + batchIndex = batchOffsets.length // Prevent reading any other batch + fileStream = null + try { + val ds = deserializeStreamToBeClosed + deserializeStreamToBeClosed = null + deserializeStream = null + if (ds != null) { + ds.close() + } + } finally { + if (file.exists()) { + file.delete() + } + } + } + } + + override def finalize(): Unit = { + try { + cleanup() + } finally { + super.finalize() + } } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 15bda1c9cc29c..d69f2d9048055 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -723,6 +723,7 @@ private[spark] class ExternalSorter[K, V, C]( partitionWriters.foreach(_.commitAndClose()) var out: FileOutputStream = null var in: FileInputStream = null + val writeStartTime = System.nanoTime try { out = new FileOutputStream(outputFile, true) for (i <- 0 until numPartitions) { @@ -739,6 +740,8 @@ private[spark] class ExternalSorter[K, V, C]( if (in != null) { in.close() } + context.taskMetrics.shuffleWriteMetrics.foreach( + _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } } else { // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by @@ -757,12 +760,13 @@ private[spark] class ExternalSorter[K, V, C]( } } - context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled - context.taskMetrics.diskBytesSpilled += diskBytesSpilled + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m => if (curWriteMetrics != null) { - m.shuffleBytesWritten += curWriteMetrics.shuffleBytesWritten - m.shuffleWriteTime += 
curWriteMetrics.shuffleWriteTime + m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten) + m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime) + m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 9f54312074856..747ecf075a397 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -42,9 +42,6 @@ private[spark] trait Spillable[C] extends Logging { // Memory manager that can be used to acquire/release memory private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager - // Threshold for `elementsRead` before we start tracking this collection's memory usage - private[this] val trackMemoryThreshold = 1000 - // Initial threshold for the size of a collection before we start tracking its memory usage // Exposed for testing private[this] val initialMemoryThreshold: Long = @@ -72,8 +69,7 @@ private[spark] trait Spillable[C] extends Logging { * @return true if `collection` was spilled to disk; false otherwise */ protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { - if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && - currentMemory >= myMemoryThreshold) { + if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 07b1e44d04be6..74e88c767ee07 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -492,6 +492,36 @@ public Integer 
call(Integer a, Integer b) { Assert.assertEquals(33, sum); } + @Test + public void treeReduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + for (int depth = 1; depth <= 10; depth++) { + int sum = rdd.treeReduce(add, depth); + Assert.assertEquals(-5, sum); + } + } + + @Test + public void treeAggregate() { + JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + for (int depth = 1; depth <= 10; depth++) { + int sum = rdd.treeAggregate(0, add, add, depth); + Assert.assertEquals(-5, sum); + } + } + @SuppressWarnings("unchecked") @Test public void aggregateByKey() { @@ -606,6 +636,27 @@ public void take() { rdd.takeSample(false, 2, 42); } + @Test + public void isEmpty() { + Assert.assertTrue(sc.emptyRDD().isEmpty()); + Assert.assertTrue(sc.parallelize(new ArrayList()).isEmpty()); + Assert.assertFalse(sc.parallelize(Arrays.asList(1)).isEmpty()); + Assert.assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter( + new Function() { + @Override + public Boolean call(Integer i) { + return i < 0; + } + }).isEmpty()); + Assert.assertFalse(sc.parallelize(Arrays.asList(1, 2, 3)).filter( + new Function() { + @Override + public Boolean call(Integer i) { + return i > 1; + } + }).isEmpty()); + } + @Test public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); @@ -657,6 +708,10 @@ public void javaDoubleRDDHistoGram() { // Test with provided buckets long[] histogram = rdd.histogram(expected_buckets); Assert.assertArrayEquals(expected_counts, histogram); + // SPARK-5744 + Assert.assertArrayEquals( + new long[] {0}, + sc.parallelizeDoubles(new ArrayList(0), 1).histogram(new double[]{0.0, 1.0})); } @Test diff --git 
a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java similarity index 93% rename from core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java rename to core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java index e9ec700e32e15..e38bc38949d7c 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.util; +package test.org.apache.spark; import org.apache.spark.TaskContext; +import org.apache.spark.util.TaskCompletionListener; /** diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java new file mode 100644 index 0000000000000..4a918f725dc91 --- /dev/null +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark; + +import org.apache.spark.TaskContext; + +/** + * Something to make sure that TaskContext can be used in Java. + */ +public class JavaTaskContextCompileCheck { + + public static void test() { + TaskContext tc = TaskContext.get(); + + tc.isCompleted(); + tc.isInterrupted(); + tc.isRunningLocally(); + + tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl()); + + tc.attemptNumber(); + tc.partitionId(); + tc.stageId(); + tc.taskAttemptId(); + } +} diff --git a/core/src/test/resources/keystore b/core/src/test/resources/keystore new file mode 100644 index 0000000000000..f8310e39ba1e0 Binary files /dev/null and b/core/src/test/resources/keystore differ diff --git a/core/src/test/resources/test_metrics_system.properties b/core/src/test/resources/test_metrics_system.properties index 35d0bd3b8d0b8..4e8b8465696e5 100644 --- a/core/src/test/resources/test_metrics_system.properties +++ b/core/src/test/resources/test_metrics_system.properties @@ -18,7 +18,5 @@ *.sink.console.period = 10 *.sink.console.unit = seconds test.sink.console.class = org.apache.spark.metrics.sink.ConsoleSink -test.sink.dummy.class = org.apache.spark.metrics.sink.DummySink -test.source.dummy.class = org.apache.spark.metrics.source.DummySource test.sink.console.period = 20 test.sink.console.unit = minutes diff --git a/core/src/test/resources/truststore b/core/src/test/resources/truststore new file mode 100644 index 0000000000000..a6b1d46e1f391 Binary files /dev/null and b/core/src/test/resources/truststore differ diff --git a/core/src/test/resources/untrusted-keystore b/core/src/test/resources/untrusted-keystore new file mode 100644 index 0000000000000..6015b02caa128 Binary files /dev/null and b/core/src/test/resources/untrusted-keystore differ diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index f087fc550dde3..bd0f8bdefa171 100644 --- 
a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import scala.collection.mutable +import scala.ref.WeakReference import org.scalatest.FunSuite import org.scalatest.Matchers @@ -136,4 +137,23 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { } } + test ("garbage collection") { + // Create an accumulator and let it go out of scope to test that it's properly garbage collected + sc = new SparkContext("local", "test") + var acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val accId = acc.id + val ref = WeakReference(acc) + + // Ensure the accumulator is present + assert(ref.get.isDefined) + + // Remove the explicit reference to it and allow weak reference to get garbage collected + acc = null + System.gc() + assert(ref.get.isEmpty) + + Accumulators.remove(accId) + assert(!Accumulators.originals.get(accId).isDefined) + } + } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index d7d9dc7b50f30..4b25c200a695a 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -17,16 +17,18 @@ package org.apache.spark +import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, FunSuite} -import org.scalatest.mock.EasyMockSugar +import org.scalatest.mock.MockitoSugar -import org.apache.spark.executor.{DataReadMethod, TaskMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.RDD import org.apache.spark.storage._ // TODO: Test the CacheManager's thread-safety aspects -class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar { - var sc : SparkContext = _ +class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter + with MockitoSugar { + var 
blockManager: BlockManager = _ var cacheManager: CacheManager = _ var split: Partition = _ @@ -57,10 +59,6 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar }.cache() } - after { - sc.stop() - } - test("get uncached rdd") { // Do not mock this test, because attempting to match Array[Any], which is not covariant, // in blockManager.put is a losing battle. You have been warned. @@ -75,29 +73,21 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } test("get cached rdd") { - expecting { - val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) - blockManager.get(RDDBlockId(0, 0)).andReturn(Some(result)) - } + val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) + when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, 0) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(5, 6, 7)) - } + val context = new TaskContextImpl(0, 0, 0, 0) + val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) + assert(value.toList === List(5, 6, 7)) } test("get uncached local rdd") { - expecting { - // Local computation should not persist the resulting value, so don't expect a put(). - blockManager.get(RDDBlockId(0, 0)).andReturn(None) - } + // Local computation should not persist the resulting value, so don't expect a put(). 
+ when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, 0, true) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(1, 2, 3, 4)) - } + val context = new TaskContextImpl(0, 0, 0, 0, true) + val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) + assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index ae2ae7ed0d3aa..cdfaacee7da40 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -382,6 +382,10 @@ class CleanerTester( toBeCleanedBroadcstIds -= broadcastId logInfo("Broadcast" + broadcastId + " cleaned") } + + def accumCleaned(accId: Long): Unit = { + logInfo("Cleaned accId " + accId + " cleaned") + } } val MAX_VALIDATION_ATTEMPTS = 10 diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 8a54360e81795..9bd5dfec8703a 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -28,31 +28,29 @@ import org.apache.spark.util.Utils class DriverSuite extends FunSuite with Timeouts { - test("driver should exit after finishing") { + test("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" - val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + val masters = Table("master", "local", "local-cluster[2,1,512]") forAll(masters) { (master: String) => - failAfter(60 seconds) { 
- Utils.executeAndGetOutput( - Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), - new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - } + val process = Utils.executeCommand( + Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + failAfter(60 seconds) { process.waitFor() } + // Ensure we still kill the process in case it timed out + process.destroy() } } } /** - * Program that creates a Spark driver but doesn't call SparkContext.stop() or - * Sys.exit() after finishing. + * Program that creates a Spark driver but doesn't call SparkContext#stop() or + * sys.exit() after finishing. */ object DriverWithoutCleanup { def main(args: Array[String]) { Utils.configTestLog4j("INFO") - // Bind the web UI to an ephemeral port in order to avoid conflicts with other tests running on - // the same machine (we shouldn't just disable the UI here, since that might mask bugs): - val conf = new SparkConf().set("spark.ui.port", "0") + val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 0e4df17c1bf87..abfcee75728dc 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -22,7 +22,8 @@ import scala.collection.mutable import org.scalatest.{FunSuite, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. 
@@ -32,24 +33,23 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { import ExecutorAllocationManagerSuite._ test("verify min/max executors") { - // No min or max val conf = new SparkConf() .setMaster("local") .setAppName("test-executor-allocation-manager") .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.testing", "true") - intercept[SparkException] { new SparkContext(conf) } - SparkEnv.get.stop() // cleanup the created environment - SparkContext.clearActiveContext() + val sc0 = new SparkContext(conf) + assert(sc0.executorAllocationManager.isDefined) + sc0.stop() - // Only min - val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1") + // Min < 0 + val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") intercept[SparkException] { new SparkContext(conf1) } SparkEnv.get.stop() SparkContext.clearActiveContext() - // Only max - val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2") + // Max < 0 + val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") intercept[SparkException] { new SparkContext(conf2) } SparkEnv.get.stop() SparkContext.clearActiveContext() @@ -145,8 +145,8 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { // Verify that running a task reduces the cap sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) assert(numExecutorsPending(manager) === 4) assert(addExecutors(manager) === 1) @@ -176,6 +176,33 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(numExecutorsPending(manager) === 9) } + test("cancel pending 
executors when no longer needed") { + sc = createSparkContext(1, 10) + val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) + + assert(numExecutorsPending(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 1) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 3) + + val task1Info = createTaskInfo(0, 0, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task1Info)) + + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 2) + + val task2Info = createTaskInfo(1, 0, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null)) + + assert(adjustRequestedExecutors(manager) === -1) + } + test("remove executors") { sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get @@ -271,15 +298,15 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeExecutor(manager, "5")) assert(removeExecutor(manager, "6")) assert(executorIds(manager).size === 10) - assert(addExecutors(manager) === 0) // still at upper limit + assert(addExecutors(manager) === 1) onExecutorRemoved(manager, "3") onExecutorRemoved(manager, "4") assert(executorIds(manager).size === 8) // Add succeeds again, now that we are no longer at the upper limit // Number of executors added restarts at 1 - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 1) // upper limit reached again + assert(addExecutors(manager) === 2) + assert(addExecutors(manager) === 1) // upper limit reached assert(addExecutors(manager) === 0) assert(executorIds(manager).size === 8) onExecutorRemoved(manager, "5") 
@@ -287,9 +314,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { onExecutorAdded(manager, "13") onExecutorAdded(manager, "14") assert(executorIds(manager).size === 8) - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 1) // upper limit reached again - assert(addExecutors(manager) === 0) + assert(addExecutors(manager) === 0) // still at upper limit onExecutorAdded(manager, "15") onExecutorAdded(manager, "16") assert(executorIds(manager).size === 10) @@ -297,7 +322,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("starting/canceling add timer") { sc = createSparkContext(2, 10) - val clock = new TestClock(8888L) + val clock = new ManualClock(8888L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -306,21 +331,21 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { onSchedulerBacklogged(manager) val firstAddTime = addTime(manager) assert(firstAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000) - clock.tick(100L) + clock.advance(100L) onSchedulerBacklogged(manager) assert(addTime(manager) === firstAddTime) // timer is already started - clock.tick(200L) + clock.advance(200L) onSchedulerBacklogged(manager) assert(addTime(manager) === firstAddTime) onSchedulerQueueEmpty(manager) // Restart add timer - clock.tick(1000L) + clock.advance(1000L) assert(addTime(manager) === NOT_SET) onSchedulerBacklogged(manager) val secondAddTime = addTime(manager) assert(secondAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000) - clock.tick(100L) + clock.advance(100L) onSchedulerBacklogged(manager) assert(addTime(manager) === secondAddTime) // timer is already started assert(addTime(manager) !== firstAddTime) @@ -329,7 +354,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("starting/canceling remove timers") { sc = createSparkContext(2, 10) - val clock = new TestClock(14444L) + 
val clock = new ManualClock(14444L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -342,17 +367,17 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("1")) val firstRemoveTime = removeTimes(manager)("1") assert(firstRemoveTime === clock.getTimeMillis + executorIdleTimeout * 1000) - clock.tick(100L) + clock.advance(100L) onExecutorIdle(manager, "1") assert(removeTimes(manager)("1") === firstRemoveTime) // timer is already started - clock.tick(200L) + clock.advance(200L) onExecutorIdle(manager, "1") assert(removeTimes(manager)("1") === firstRemoveTime) - clock.tick(300L) + clock.advance(300L) onExecutorIdle(manager, "2") assert(removeTimes(manager)("2") !== firstRemoveTime) // different executor assert(removeTimes(manager)("2") === clock.getTimeMillis + executorIdleTimeout * 1000) - clock.tick(400L) + clock.advance(400L) onExecutorIdle(manager, "3") assert(removeTimes(manager)("3") !== firstRemoveTime) assert(removeTimes(manager)("3") === clock.getTimeMillis + executorIdleTimeout * 1000) @@ -361,7 +386,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("3")) // Restart remove timer - clock.tick(1000L) + clock.advance(1000L) onExecutorBusy(manager, "1") assert(removeTimes(manager).size === 2) onExecutorIdle(manager, "1") @@ -377,7 +402,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("mock polling loop with no events") { sc = createSparkContext(1, 20) val manager = sc.executorAllocationManager.get - val clock = new TestClock(2020L) + val clock = new ManualClock(2020L) manager.setClock(clock) // No events - we should not be adding or removing @@ -386,15 +411,15 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) - 
clock.tick(100L) + clock.advance(100L) schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(1000L) + clock.advance(1000L) schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(10000L) + clock.advance(10000L) schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) @@ -402,57 +427,57 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("mock polling loop add behavior") { sc = createSparkContext(1, 20) - val clock = new TestClock(2020L) + val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Scheduler queue backlogged onSchedulerBacklogged(manager) - clock.tick(schedulerBacklogTimeout * 1000 / 2) + clock.advance(schedulerBacklogTimeout * 1000 / 2) schedule(manager) assert(numExecutorsPending(manager) === 0) // timer not exceeded yet - clock.tick(schedulerBacklogTimeout * 1000) + clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 1) // first timer exceeded - clock.tick(sustainedSchedulerBacklogTimeout * 1000 / 2) + clock.advance(sustainedSchedulerBacklogTimeout * 1000 / 2) schedule(manager) assert(numExecutorsPending(manager) === 1) // second timer not exceeded yet - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 1 + 2) // second timer exceeded - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 1 + 2 + 4) // third timer exceeded // Scheduler queue drained onSchedulerQueueEmpty(manager) - 
clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7) // timer is canceled - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7) // Scheduler queue backlogged again onSchedulerBacklogged(manager) - clock.tick(schedulerBacklogTimeout * 1000) + clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7 + 1) // timer restarted - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7 + 1 + 2) - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7 + 1 + 2 + 4) - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 20) // limit reached } test("mock polling loop remove behavior") { sc = createSparkContext(1, 20) - val clock = new TestClock(2020L) + val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -462,11 +487,11 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { onExecutorAdded(manager, "executor-3") assert(removeTimes(manager).size === 3) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(executorIdleTimeout * 1000 / 2) + clock.advance(executorIdleTimeout * 1000 / 2) schedule(manager) assert(removeTimes(manager).size === 3) // idle threshold not reached yet assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(executorIdleTimeout * 1000) + clock.advance(executorIdleTimeout * 1000) schedule(manager) assert(removeTimes(manager).isEmpty) // idle 
threshold exceeded assert(executorsPendingToRemove(manager).size === 2) // limit reached (1 executor remaining) @@ -487,7 +512,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(!removeTimes(manager).contains("executor-5")) assert(!removeTimes(manager).contains("executor-6")) assert(executorsPendingToRemove(manager).size === 2) - clock.tick(executorIdleTimeout * 1000) + clock.advance(executorIdleTimeout * 1000) schedule(manager) assert(removeTimes(manager).isEmpty) // idle executors are removed assert(executorsPendingToRemove(manager).size === 4) @@ -505,7 +530,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("executor-5")) assert(removeTimes(manager).contains("executor-6")) assert(executorsPendingToRemove(manager).size === 4) - clock.tick(executorIdleTimeout * 1000) + clock.advance(executorIdleTimeout * 1000) schedule(manager) assert(removeTimes(manager).isEmpty) assert(executorsPendingToRemove(manager).size === 6) // limit reached (1 executor remaining) @@ -579,30 +604,28 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).isEmpty) // New executors have registered - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(removeTimes(manager).contains("executor-1")) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-2", "host2", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-2", new ExecutorInfo("host2", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) 
assert(removeTimes(manager).size === 2) assert(removeTimes(manager).contains("executor-2")) // Existing executors have disconnected - sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved( - 0L, BlockManagerId("executor-1", "host1", 1))) + sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-1", "")) assert(executorIds(manager).size === 1) assert(!executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(!removeTimes(manager).contains("executor-1")) // Unknown executor has disconnected - sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved( - 0L, BlockManagerId("executor-3", "host3", 1))) + sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-3", "")) assert(executorIds(manager).size === 1) assert(removeTimes(manager).size === 1) } @@ -614,8 +637,8 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).isEmpty) sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 0) @@ -626,16 +649,16 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) assert(executorIds(manager).size === 1) 
assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 0) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-2", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-2", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) assert(removeTimes(manager).size === 1) @@ -682,6 +705,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _numExecutorsToAdd = PrivateMethod[Int]('numExecutorsToAdd) private val _numExecutorsPending = PrivateMethod[Int]('numExecutorsPending) + private val _maxNumExecutorsNeeded = PrivateMethod[Int]('maxNumExecutorsNeeded) private val _executorsPendingToRemove = PrivateMethod[collection.Set[String]]('executorsPendingToRemove) private val _executorIds = PrivateMethod[collection.Set[String]]('executorIds) @@ -689,6 +713,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _removeTimes = PrivateMethod[collection.Map[String, Long]]('removeTimes) private val _schedule = PrivateMethod[Unit]('schedule) private val _addExecutors = PrivateMethod[Int]('addExecutors) + private val _addOrCancelExecutorRequests = PrivateMethod[Int]('addOrCancelExecutorRequests) private val _removeExecutor = PrivateMethod[Boolean]('removeExecutor) private val _onExecutorAdded = PrivateMethod[Unit]('onExecutorAdded) private val _onExecutorRemoved = PrivateMethod[Unit]('onExecutorRemoved) @@ -727,7 +752,12 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { } private def addExecutors(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _addExecutors() + val maxNumExecutorsNeeded = manager invokePrivate _maxNumExecutorsNeeded() + manager invokePrivate _addExecutors(maxNumExecutorsNeeded) + } + + private def adjustRequestedExecutors(manager: ExecutorAllocationManager): 
Int = { + manager invokePrivate _addOrCancelExecutorRequests(0L) } private def removeExecutor(manager: ExecutorAllocationManager, id: String): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 0f49ce4754fbb..5fdf6bc2777e3 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -18,13 +18,19 @@ package org.apache.spark import java.io._ +import java.net.URI import java.util.jar.{JarEntry, JarOutputStream} +import javax.net.ssl.SSLHandshakeException import com.google.common.io.ByteStreams +import org.apache.commons.io.{FileUtils, IOUtils} +import org.apache.commons.lang3.RandomUtils import org.scalatest.FunSuite import org.apache.spark.util.Utils +import SSLSampleConfigs._ + class FileServerSuite extends FunSuite with LocalSparkContext { @transient var tmpDir: File = _ @@ -168,4 +174,88 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } } + test ("HttpFileServer should work with SSL") { + val sparkConf = sparkSSLConfig() + val sm = new SecurityManager(sparkConf) + val server = new HttpFileServer(sparkConf, sm, 0) + try { + server.initialize() + + fileTransferTest(server, sm) + } finally { + server.stop() + } + } + + test ("HttpFileServer should work with SSL and good credentials") { + val sparkConf = sparkSSLConfig() + sparkConf.set("spark.authenticate", "true") + sparkConf.set("spark.authenticate.secret", "good") + + val sm = new SecurityManager(sparkConf) + val server = new HttpFileServer(sparkConf, sm, 0) + try { + server.initialize() + + fileTransferTest(server, sm) + } finally { + server.stop() + } + } + + test ("HttpFileServer should not work with valid SSL and bad credentials") { + val sparkConf = sparkSSLConfig() + sparkConf.set("spark.authenticate", "true") + sparkConf.set("spark.authenticate.secret", "bad") + + val sm = new SecurityManager(sparkConf) + val 
server = new HttpFileServer(sparkConf, sm, 0) + try { + server.initialize() + + intercept[IOException] { + fileTransferTest(server) + } + } finally { + server.stop() + } + } + + test ("HttpFileServer should not work with SSL when the server is untrusted") { + val sparkConf = sparkSSLConfigUntrusted() + val sm = new SecurityManager(sparkConf) + val server = new HttpFileServer(sparkConf, sm, 0) + try { + server.initialize() + + intercept[SSLHandshakeException] { + fileTransferTest(server) + } + } finally { + server.stop() + } + } + + def fileTransferTest(server: HttpFileServer, sm: SecurityManager = null): Unit = { + val randomContent = RandomUtils.nextBytes(100) + val file = File.createTempFile("FileServerSuite", "sslTests", tmpDir) + FileUtils.writeByteArrayToFile(file, randomContent) + server.addFile(file) + + val uri = new URI(server.serverUri + "/files/" + file.getName) + + val connection = if (sm != null && sm.isAuthenticationEnabled()) { + Utils.constructURIForAuthentication(uri, sm).toURL.openConnection() + } else { + uri.toURL.openConnection() + } + + if (sm != null) { + Utils.setupSecureURLConnection(connection, sm) + } + + val buf = IOUtils.toByteArray(connection.getInputStream) + assert(buf === randomContent) + } + } diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 5e24196101fbc..7acd27c735727 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.FunSuite -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} import org.apache.spark.util.Utils diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala 
b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 7584ae79fc920..21487bc24d58a 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -171,11 +171,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter assert(jobB.get() === 100) } - ignore("two jobs sharing the same stage") { + test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched - // sem2: make sure the first stage is not finished until cancel is issued + // twoJobsSharingStageSemaphore: + // make sure the first stage is not finished until cancel is issued val sem1 = new Semaphore(0) - val sem2 = new Semaphore(0) sc = new SparkContext("local[2]", "test") sc.addSparkListener(new SparkListener { @@ -186,7 +186,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter // Create two actions that would share the some stages. val rdd = sc.parallelize(1 to 10, 2).map { i => - sem2.acquire() + JobCancellationSuite.twoJobsSharingStageSemaphore.acquire() (i, i) }.reduceByKey(_+_) val f1 = rdd.collectAsync() @@ -196,13 +196,13 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter future { sem1.acquire() f1.cancel() - sem2.release(10) + JobCancellationSuite.twoJobsSharingStageSemaphore.release(10) } - // Expect both to fail now. - // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2. 
+ // Expect f1 to fail due to cancellation, intercept[SparkException] { f1.get() } - intercept[SparkException] { f2.get() } + // but f2 should not be affected + f2.get() } def testCount() { @@ -268,4 +268,5 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter object JobCancellationSuite { val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) + val twoJobsSharingStageSemaphore = new Semaphore(0) } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index d27880f4bc32f..ccfe0678cb1c3 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -120,7 +120,7 @@ class MapOutputTrackerSuite extends FunSuite { securityManager = new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) val timeout = AkkaUtils.lookupTimeout(conf) slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala new file mode 100644 index 0000000000000..444a33371bd71 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.File + +import com.google.common.io.Files +import org.apache.spark.util.Utils +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { + + test("test resolving property file as spark conf ") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val opts = SSLOptions.parse(conf, "spark.ssl") + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + 
assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + test("test resolving property with defaults specified ") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + test("test whether defaults can be overridden ") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ui.ssl.enabled", 
"false") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ui.ssl.keyStorePassword", "12345") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ui.ssl.enabledAlgorithms", "ABC, DEF") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + + assert(opts.enabled === false) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("12345")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === Set("ABC", "DEF")) + } + +} diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala new file mode 100644 index 0000000000000..ace8123a8961f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.File + +object SSLSampleConfigs { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val untrustedKeyStorePath = new File(this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + def sparkSSLConfig() = { + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + conf.set("spark.ssl.protocol", "TLSv1") + conf + } + + def sparkSSLConfigUntrusted() = { + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", untrustedKeyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + conf.set("spark.ssl.protocol", "TLSv1") + conf + } + +} diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index fcca0867b8072..43fbd3ff3f756 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import scala.collection.mutable.ArrayBuffer +import java.io.File import org.scalatest.FunSuite @@ -125,6 +125,54 @@ class SecurityManagerSuite extends FunSuite { } + test("ssl on setup") { + val conf = SSLSampleConfigs.sparkSSLConfig() + + val securityManager = new SecurityManager(conf) + + assert(securityManager.fileServerSSLOptions.enabled === true) + assert(securityManager.akkaSSLOptions.enabled === true) + + assert(securityManager.sslSocketFactory.isDefined === true) + assert(securityManager.hostnameVerifier.isDefined === true) + + assert(securityManager.fileServerSSLOptions.trustStore.isDefined === true) + assert(securityManager.fileServerSSLOptions.trustStore.get.getName === "truststore") + assert(securityManager.fileServerSSLOptions.keyStore.isDefined === true) + assert(securityManager.fileServerSSLOptions.keyStore.get.getName === "keystore") + assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) + assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) + assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) + assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) + assert(securityManager.fileServerSSLOptions.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + + assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) + assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") + assert(securityManager.akkaSSLOptions.keyStore.isDefined === true) + assert(securityManager.akkaSSLOptions.keyStore.get.getName === "keystore") + assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password")) + 
assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password")) + assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) + assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) + assert(securityManager.akkaSSLOptions.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + } + + test("ssl off setup") { + val file = File.createTempFile("SSLOptionsSuite", "conf") + file.deleteOnExit() + + System.setProperty("spark.ssl.configFile", file.getAbsolutePath) + val conf = new SparkConf() + + val securityManager = new SecurityManager(conf) + + assert(securityManager.fileServerSSLOptions.enabled === false) + assert(securityManager.akkaSSLOptions.enabled === false) + assert(securityManager.sslSocketFactory.isDefined === false) + assert(securityManager.hostnameVerifier.isDefined === false) + } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 790976a5ac308..ea6b73bc68b34 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark +import java.util.concurrent.{TimeUnit, Executors} + +import scala.util.{Try, Random} + import org.scalatest.FunSuite import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.ResetSystemProperties @@ -123,6 +127,27 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(conf.get("spark.test.a.b.c") === "a.b.c") } + test("Thread safeness - SPARK-5425") { + import scala.collection.JavaConversions._ + val executor = Executors.newSingleThreadScheduledExecutor() + val sf = executor.scheduleAtFixedRate(new Runnable { + override def run(): Unit = + System.setProperty("spark.5425." 
+ Random.nextInt(), Random.nextInt().toString) + }, 0, 1, TimeUnit.MILLISECONDS) + + try { + val t0 = System.currentTimeMillis() + while ((System.currentTimeMillis() - t0) < 1000) { + val conf = Try(new SparkConf(loadDefaults = true)) + assert(conf.isSuccess === true) + } + } finally { + executor.shutdownNow() + for (key <- System.getProperties.stringPropertyNames() if key.startsWith("spark.5425.")) + System.getProperties.remove(key) + } + } + test("register kryo classes through registerKryoClasses") { val conf = new SparkConf().set("spark.kryo.registrationRequired", "true") @@ -172,6 +197,18 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro serializer.newInstance().serialize(new StringBuffer()) } + test("deprecated config keys") { + val conf = new SparkConf() + .set("spark.files.userClassPathFirst", "true") + .set("spark.yarn.user.classpath.first", "true") + assert(conf.contains("spark.files.userClassPathFirst")) + assert(conf.contains("spark.executor.userClassPathFirst")) + assert(conf.contains("spark.yarn.user.classpath.first")) + assert(conf.getBoolean("spark.files.userClassPathFirst", false)) + assert(conf.getBoolean("spark.executor.userClassPathFirst", false)) + assert(conf.getBoolean("spark.yarn.user.classpath.first", false)) + } + } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 8ae4f243ec1ae..bbed8ddc6bafc 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -149,7 +149,7 @@ class SparkContextSchedulerCreationSuite } test("yarn-client") { - testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnScheduler") } def testMesos(master: String, expectedClass: 
Class[_], coarse: Boolean) { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 8b3c6871a7b39..50f347f1954de 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -17,10 +17,17 @@ package org.apache.spark +import java.io.File + +import com.google.common.base.Charsets._ +import com.google.common.io.Files + import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable +import org.apache.spark.util.Utils + class SparkContextSuite extends FunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { @@ -72,4 +79,74 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { val byteArray2 = converter.convert(bytesWritable) assert(byteArray2.length === 0) } + + test("addFile works") { + val file = File.createTempFile("someprefix", "somesuffix") + val absolutePath = file.getAbsolutePath + try { + Files.write("somewords", file, UTF_8) + val length = file.length() + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(file.getAbsolutePath) + sc.parallelize(Array(1), 1).map(x => { + val gotten = new File(SparkFiles.get(file.getName)) + if (!gotten.exists()) { + throw new SparkException("file doesn't exist") + } + if (length != gotten.length()) { + throw new SparkException( + s"file has different length $length than added file ${gotten.length()}") + } + if (absolutePath == gotten.getAbsolutePath) { + throw new SparkException("file should have been copied") + } + x + }).count() + } finally { + sc.stop() + } + } + + test("addFile recursive works") { + val pluto = Utils.createTempDir() + val neptune = Utils.createTempDir(pluto.getAbsolutePath) + val saturn = Utils.createTempDir(neptune.getAbsolutePath) + val alien1 = File.createTempFile("alien", "1", neptune) + val alien2 = File.createTempFile("alien", "2", saturn) + 
+ try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(neptune.getAbsolutePath, true) + sc.parallelize(Array(1), 1).map(x => { + val sep = File.separator + if (!new File(SparkFiles.get(neptune.getName + sep + alien1.getName)).exists()) { + throw new SparkException("can't access file under root added directory") + } + if (!new File(SparkFiles.get(neptune.getName + sep + saturn.getName + sep + alien2.getName)) + .exists()) { + throw new SparkException("can't access file in nested directory") + } + if (new File(SparkFiles.get(pluto.getName + sep + neptune.getName + sep + alien1.getName)) + .exists()) { + throw new SparkException("file exists that shouldn't") + } + x + }).count() + } finally { + sc.stop() + } + } + + test("addFile recursive can't add directories by default") { + val dir = Utils.createTempDir() + + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + intercept[SparkException] { + sc.addFile(dir.getAbsolutePath) + } + } finally { + sc.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index 7b866f08a0e9f..c63d834f9048b 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -23,11 +23,22 @@ import org.scalatest.FunSuite class PythonRDDSuite extends FunSuite { - test("Writing large strings to the worker") { - val input: List[String] = List("a"*100000) - val buffer = new DataOutputStream(new ByteArrayOutputStream) - PythonRDD.writeIteratorToStream(input.iterator, buffer) - } + test("Writing large strings to the worker") { + val input: List[String] = List("a"*100000) + val buffer = new DataOutputStream(new ByteArrayOutputStream) + PythonRDD.writeIteratorToStream(input.iterator, buffer) + } + test("Handle nulls gracefully") { + val buffer = new 
DataOutputStream(new ByteArrayOutputStream) + // Should not have NPE when write an Iterator with null in it + // The correctness will be tested in Python + PythonRDD.writeIteratorToStream(Iterator("a", null), buffer) + PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer) + PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer) + PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer) + PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer) + PythonRDD.writeIteratorToStream( + Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer) + } } - diff --git a/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala new file mode 100644 index 0000000000000..f8c39326145e1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.python + +import org.scalatest.FunSuite + +import org.apache.spark.SharedSparkContext + +class SerDeUtilSuite extends FunSuite with SharedSparkContext { + + test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") { + val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) + SerDeUtil.pairRDDToPython(emptyRdd, 10) + } + + test("Converting an empty python RDD to pair RDD does not throw an exception (SPARK-5441)") { + val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) + val javaRdd = emptyRdd.toJavaRDD() + val pythonRdd = SerDeUtil.javaToPython(javaRdd) + SerDeUtil.pythonToPairRDD(pythonRdd, false) + } +} + diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index b0a70f012f1f3..af3272692d7a1 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { testPackage.runCallSiteTest(sc) } + test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") { + sc = new SparkContext("local", "test") + sc.stop() + val thrown = intercept[IllegalStateException] { + sc.broadcast(Seq(1, 2, 3)) + } + assert(thrown.getMessage.toLowerCase.contains("stopped")) + } + /** * Verify the persistence of state associated with an HttpBroadcast in either local mode or * local-cluster mode (when distributed = true). 
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) - val broadcast = sc.broadcast(rdd) + val broadcast = sc.broadcast(Array(1, 2, 3, 4)) broadcast.destroy() val thrown = intercept[SparkException] { broadcast.value } assert(thrown.getMessage.contains("BroadcastSuite.scala")) diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index d2dae34be7bfb..518073dcbb64e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.Matchers class ClientSuite extends FunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) + ClientArguments.isValidJarUrl("https://someHost:8080/foo.jar") should be (true) // file scheme with authority and path is valid. 
ClientArguments.isValidJarUrl("file://somehost/path/to/a/jarFile.jar") should be (true) diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index aa65f7e8915e6..e955636cf5b59 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -68,7 +68,8 @@ class JsonProtocolSuite extends FunSuite { val completedApps = Array[ApplicationInfo]() val activeDrivers = Array(createDriverInfo()) val completedDrivers = Array(createDriverInfo()) - val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps, + val stateResponse = new MasterStateResponse( + "host", 8080, None, workers, activeApps, completedApps, activeDrivers, completedDrivers, RecoveryState.ALIVE) val output = JsonProtocol.writeMasterState(stateResponse) assertValidJson(output) @@ -117,7 +118,7 @@ class JsonProtocolSuite extends FunSuite { } def createExecutorRunner(): ExecutorRunner = { - new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", + new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", 123, new File("sparkHome"), new File("workDir"), "akka://worker", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) } diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala new file mode 100644 index 0000000000000..f33bdc73e40ac --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import scala.collection.mutable + +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} +import org.apache.spark.{SparkContext, LocalSparkContext} + +class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { + + /** Length of time to wait while draining listener events. 
*/ + val WAIT_TIMEOUT_MILLIS = 10000 + + before { + sc = new SparkContext("local-cluster[2,1,512]", "test") + } + + test("verify log urls get propagated from workers") { + val listener = new SaveExecutorInfo + sc.addSparkListener(listener) + + val rdd1 = sc.parallelize(1 to 100, 4) + val rdd2 = rdd1.map(_.toString) + rdd2.setName("Target RDD") + rdd2.count() + + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + listener.addedExecutorInfos.values.foreach { info => + assert(info.logUrlMap.nonEmpty) + } + } + + private class SaveExecutorInfo extends SparkListener { + val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() + + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfos(executor.executorId) = executor.executorInfo + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 065b7534cece6..46d745c4ecbfa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,25 +21,30 @@ import java.io._ import scala.collection.mutable.ArrayBuffer +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.ByteStreams +import org.scalatest.FunSuite +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.FunSuite -import org.scalatest.Matchers // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch // of properties that neeed to be cleared after tests. 
-class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties { +class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties with Timeouts { def beforeAll() { System.setProperty("spark.testing", "true") } - val noOpOutputStream = new OutputStream { + private val noOpOutputStream = new OutputStream { def write(b: Int) = {} } /** Simple PrintStream that reads data into a buffer */ - class BufferPrintStream extends PrintStream(noOpOutputStream) { + private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() override def println(line: String) { lineBuffer += line @@ -47,7 +52,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties } /** Returns true if the script exits and the given search string is printed. */ - def testPrematureExit(input: Array[String], searchString: String) = { + private def testPrematureExit(input: Array[String], searchString: String) = { val printStream = new BufferPrintStream() SparkSubmit.printStream = printStream @@ -138,7 +143,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--executor-memory 5g") @@ -177,7 +182,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath 
should have length (4) @@ -198,6 +203,18 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties } test("handles standalone cluster mode") { + testStandaloneCluster(useRest = true) + } + + test("handles legacy standalone cluster mode") { + testStandaloneCluster(useRest = false) + } + + /** + * Test whether the launch environment is correctly set up in standalone cluster mode. + * @param useRest whether to use the REST submission gateway introduced in Spark 1.3 + */ + private def testStandaloneCluster(useRest: Boolean): Unit = { val clArgs = Seq( "--deploy-mode", "cluster", "--master", "spark://h:p", @@ -209,17 +226,26 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + appArgs.useRest = useRest + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") - childArgsStr should startWith ("--memory 4g --cores 5 --supervise") - childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.Client") - classpath should have size (0) - sysProps should have size (5) + if (useRest) { + childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") + mainClass should be ("org.apache.spark.deploy.rest.StandaloneRestClient") + } else { + childArgsStr should startWith ("--supervise --memory 4g --cores 5") + childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2" + mainClass should be ("org.apache.spark.deploy.Client") + } + classpath should have size 0 + sysProps should have size 8 sysProps.keys should contain ("SPARK_SUBMIT") sysProps.keys should contain ("spark.master") sysProps.keys should contain ("spark.app.name") sysProps.keys should contain ("spark.jars") + sysProps.keys 
should contain ("spark.driver.memory") + sysProps.keys should contain ("spark.driver.cores") + sysProps.keys should contain ("spark.driver.supervise") sysProps.keys should contain ("spark.shuffle.spill") sysProps("spark.shuffle.spill") should be ("false") } @@ -236,7 +262,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -258,7 +284,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -278,7 +304,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, mainClass) = createLaunchEnv(appArgs) + val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) sysProps("spark.executor.memory") should be ("5g") sysProps("spark.master") should be ("yarn-cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") @@ -290,7 +316,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", - "--conf", "spark.ui.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -305,8 +330,21 @@ class SparkSubmitSuite extends 
FunSuite with Matchers with ResetSystemProperties "--name", "testApp", "--master", "local-cluster[2,1,512]", "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("includes jars passed in through --packages") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val packagesString = "com.databricks:spark-csv_2.10:0.1,com.databricks:spark-avro_2.10:0.1" + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,512]", + "--packages", packagesString, "--conf", "spark.ui.enabled=false", - unusedJar.toString) + unusedJar.toString, + "com.databricks.spark.csv.DefaultSource", "com.databricks.spark.avro.DefaultSource") runSparkSubmit(args) } @@ -324,7 +362,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) @@ -339,7 +377,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should be (Utils.resolveURIs(archives)) sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) @@ -352,7 +390,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + val sysProps3 = 
SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) sysProps3("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -377,7 +415,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) sysProps("spark.files") should be(Utils.resolveURIs(files)) @@ -394,7 +432,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) @@ -409,11 +447,24 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 sysProps3("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) } + test("user classpath first in driver") { + val systemJar = TestUtils.createJarWithFiles(Map("test.resource" -> "SYSTEM")) + val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER")) + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + "--conf", "spark.driver.extraClassPath=" + systemJar, + "--conf", "spark.driver.userClassPathFirst=true", + userJar.toString) + runSparkSubmit(args) + } + 
test("SPARK_CONF_DIR overrides spark-defaults.conf") { forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => val unusedJar = TestUtils.createJarWithClasses(Seq.empty) @@ -425,20 +476,23 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) assert(appArgs.propertiesFile != null) assert(appArgs.propertiesFile.startsWith(path)) - appArgs.executorMemory should be ("2.3g") + appArgs.executorMemory should be ("2.3g") } } // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. - def runSparkSubmit(args: Seq[String]): String = { + private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - Utils.executeAndGetOutput( + val process = Utils.executeCommand( Seq("./bin/spark-submit") ++ args, new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + failAfter(60 seconds) { process.waitFor() } + // Ensure we still kill the process in case it timed out + process.destroy() } - def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { + private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") @@ -463,8 +517,8 @@ object JarCreationTest extends Logging { val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => var exception: String = null try { - Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) - Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString @@ -502,3 +556,15 @@ object 
SimpleApplicationTest { } } } + +object UserClasspathFirstTest { + def main(args: Array[String]) { + val ccl = Thread.currentThread().getContextClassLoader() + val resource = ccl.getResourceAsStream("test.resource") + val bytes = ByteStreams.toByteArray(resource) + val contents = new String(bytes, 0, bytes.length, UTF_8) + if (contents != "USER") { + throw new SparkException("Should have read user resource, but instead read: " + contents) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala new file mode 100644 index 0000000000000..ad62b35f624f6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.io.{PrintStream, OutputStream, File} + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.ivy.core.module.descriptor.MDArtifact +import org.apache.ivy.plugins.resolver.IBiblioResolver + +class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { + + private val noOpOutputStream = new OutputStream { + def write(b: Int) = {} + } + + /** Simple PrintStream that reads data into a buffer */ + private class BufferPrintStream extends PrintStream(noOpOutputStream) { + var lineBuffer = ArrayBuffer[String]() + override def println(line: String) { + lineBuffer += line + } + } + + override def beforeAll() { + super.beforeAll() + // We don't want to write logs during testing + SparkSubmitUtils.printStream = new BufferPrintStream + } + + test("incorrect maven coordinate throws error") { + val coordinates = Seq("a:b: ", " :a:b", "a: :b", "a:b:", ":a:b", "a::b", "::", "a:b", "a") + for (coordinate <- coordinates) { + intercept[IllegalArgumentException] { + SparkSubmitUtils.extractMavenCoordinates(coordinate) + } + } + } + + test("create repo resolvers") { + val resolver1 = SparkSubmitUtils.createRepoResolvers(None) + // should have central and spark-packages by default + assert(resolver1.getResolvers.size() === 2) + assert(resolver1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "central") + assert(resolver1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "spark-packages") + + val repos = "a/1,b/2,c/3" + val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos)) + assert(resolver2.getResolvers.size() === 5) + val expected = repos.split(",").map(r => s"$r/") + resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: IBiblioResolver, i) => + if (i == 0) { + assert(resolver.getName === "central") + } else if (i == 1) { + assert(resolver.getName === "spark-packages") + } else { + 
assert(resolver.getName === s"repo-${i - 1}") + assert(resolver.getRoot === expected(i - 2)) + } + } + } + + test("add dependencies works correctly") { + val md = SparkSubmitUtils.getModuleDescriptor + val artifacts = SparkSubmitUtils.extractMavenCoordinates("com.databricks:spark-csv_2.10:0.1," + + "com.databricks:spark-avro_2.10:0.1") + + SparkSubmitUtils.addDependenciesToIvy(md, artifacts, "default") + assert(md.getDependencies.length === 2) + } + + test("ivy path works correctly") { + val ivyPath = "dummy/ivy" + val md = SparkSubmitUtils.getModuleDescriptor + val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") + var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) + for (i <- 0 until 3) { + val index = jPaths.indexOf(ivyPath) + assert(index >= 0) + jPaths = jPaths.substring(index + ivyPath.length) + } + // end to end + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + "com.databricks:spark-csv_2.10:0.1", None, Option(ivyPath), true) + assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + } + + test("search for artifact at other repositories") { + val path = SparkSubmitUtils.resolveMavenCoordinates("com.agimatec:agimatec-validation:0.9.3", + Option("https://oss.sonatype.org/content/repositories/agimatec/"), None, true) + assert(path.indexOf("agimatec-validation") >= 0, "should find package. If it doesn't, check" + + "if package still exists. 
If it has been removed, replace the example in this test.") + } + + test("dependency not found throws RuntimeException") { + intercept[RuntimeException] { + SparkSubmitUtils.resolveMavenCoordinates("a:b:c", None, None, true) + } + } + + test("neglects Spark and Spark's dependencies") { + val path = SparkSubmitUtils.resolveMavenCoordinates( + "org.apache.spark:spark-core_2.10:1.2.0", None, None, true) + assert(path === "", "should return empty path") + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 8379883e065e7..85939eaadccc7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -37,13 +37,8 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers private var testDir: File = null - private var provider: FsHistoryProvider = null - before { testDir = Utils.createTempDir() - provider = new FsHistoryProvider(new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.history.fs.updateInterval", "0")) } after { @@ -51,40 +46,41 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers } test("Parse new and old application logs") { - val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.history.fs.updateInterval", "0") - val provider = new FsHistoryProvider(conf) + val provider = new FsHistoryProvider(createTestConf()) // Write a new-style application log. 
- val logFile1 = new File(testDir, "new1") - writeFile(logFile1, true, None, - SparkListenerApplicationStart("app1-1", None, 1L, "test"), - SparkListenerApplicationEnd(2L) + val newAppComplete = new File(testDir, "new1") + writeFile(newAppComplete, true, None, + SparkListenerApplicationStart("new-app-complete", None, 1L, "test"), + SparkListenerApplicationEnd(4L) ) // Write an unfinished app, new-style. - val logFile2 = new File(testDir, "new2" + EventLoggingListener.IN_PROGRESS) - writeFile(logFile2, true, None, - SparkListenerApplicationStart("app2-2", None, 1L, "test") + val newAppIncomplete = new File(testDir, "new2" + EventLoggingListener.IN_PROGRESS) + writeFile(newAppIncomplete, true, None, + SparkListenerApplicationStart("new-app-incomplete", None, 1L, "test") ) // Write an old-style application log. - val oldLog = new File(testDir, "old1") - oldLog.mkdir() - createEmptyFile(new File(oldLog, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldLog, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("app3", None, 2L, "test"), + val oldAppComplete = new File(testDir, "old1") + oldAppComplete.mkdir() + createEmptyFile(new File(oldAppComplete, provider.SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(oldAppComplete, provider.LOG_PREFIX + "1"), false, None, + SparkListenerApplicationStart("old-app-complete", None, 2L, "test"), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(oldLog, provider.APPLICATION_COMPLETE)) + createEmptyFile(new File(oldAppComplete, provider.APPLICATION_COMPLETE)) + + // Check for logs so that we force the older unfinished app to be loaded, to make + // sure unfinished apps are also sorted correctly. + provider.checkForLogs() // Write an unfinished app, old-style. 
- val oldLog2 = new File(testDir, "old2") - oldLog2.mkdir() - createEmptyFile(new File(oldLog2, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldLog2, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("app4", None, 2L, "test") + val oldAppIncomplete = new File(testDir, "old2") + oldAppIncomplete.mkdir() + createEmptyFile(new File(oldAppIncomplete, provider.SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(oldAppIncomplete, provider.LOG_PREFIX + "1"), false, None, + SparkListenerApplicationStart("old-app-incomplete", None, 2L, "test") ) // Force a reload of data from the log directory, and check that both logs are loaded. @@ -96,14 +92,14 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers list.size should be (4) list.count(e => e.completed) should be (2) - list(0) should be (ApplicationHistoryInfo(oldLog.getName(), "app3", 2L, 3L, - oldLog.lastModified(), "test", true)) - list(1) should be (ApplicationHistoryInfo(logFile1.getName(), "app1-1", 1L, 2L, - logFile1.lastModified(), "test", true)) - list(2) should be (ApplicationHistoryInfo(oldLog2.getName(), "app4", 2L, -1L, - oldLog2.lastModified(), "test", false)) - list(3) should be (ApplicationHistoryInfo(logFile2.getName(), "app2-2", 1L, -1L, - logFile2.lastModified(), "test", false)) + list(0) should be (ApplicationHistoryInfo(newAppComplete.getName(), "new-app-complete", 1L, 4L, + newAppComplete.lastModified(), "test", true)) + list(1) should be (ApplicationHistoryInfo(oldAppComplete.getName(), "old-app-complete", 2L, 3L, + oldAppComplete.lastModified(), "test", true)) + list(2) should be (ApplicationHistoryInfo(oldAppIncomplete.getName(), "old-app-incomplete", 2L, + -1L, oldAppIncomplete.lastModified(), "test", false)) + list(3) should be (ApplicationHistoryInfo(newAppIncomplete.getName(), "new-app-incomplete", 1L, + -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. 
list.foreach { case info => @@ -113,6 +109,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers } test("Parse legacy logs with compression codec set") { + val provider = new FsHistoryProvider(createTestConf()) val testCodecs = List((classOf[LZFCompressionCodec].getName(), true), (classOf[SnappyCompressionCodec].getName(), true), ("invalid.codec", false)) @@ -156,10 +153,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers ) logFile2.setReadable(false, false) - val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.history.fs.updateInterval", "0") - val provider = new FsHistoryProvider(conf) + val provider = new FsHistoryProvider(createTestConf()) provider.checkForLogs() val list = provider.getListing().toSeq @@ -167,6 +161,42 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers list.size should be (1) } + test("history file is renamed from inprogress to completed") { + val provider = new FsHistoryProvider(createTestConf()) + + val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS) + writeFile(logFile1, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), + SparkListenerApplicationEnd(2L) + ) + provider.checkForLogs() + val appListBeforeRename = provider.getListing() + appListBeforeRename.size should be (1) + appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS) + + logFile1.renameTo(new File(testDir, "app1")) + provider.checkForLogs() + val appListAfterRename = provider.getListing() + appListAfterRename.size should be (1) + appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS) + } + + test("SPARK-5582: empty log directory") { + val provider = new FsHistoryProvider(createTestConf()) + + val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS) + writeFile(logFile1, true, None, + 
SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), + SparkListenerApplicationEnd(2L)) + + val oldLog = new File(testDir, "old1") + oldLog.mkdir() + + provider.checkForLogs() + val appListAfterRename = provider.getListing() + appListAfterRename.size should be (1) + } + private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec], events: SparkListenerEvent*) = { val out = @@ -188,4 +218,8 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers new FileOutputStream(file).close() } + private def createTestConf(): SparkConf = { + new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 3d2335f9b3637..34c74d87f0a62 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -20,30 +20,46 @@ package org.apache.spark.deploy.master import akka.actor.Address import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SSLOptions, SparkConf, SparkException} class MasterSuite extends FunSuite { test("toAkkaUrl") { - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234") + val conf = new SparkConf(loadDefaults = false) + val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.tcp") assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) } + test("toAkkaUrl with SSL") { + val conf = new SparkConf(loadDefaults = false) + val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.ssl.tcp") + assert("akka.ssl.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) + } + test("toAkkaUrl: a typo url") { + val conf = new SparkConf(loadDefaults = false) val e = intercept[SparkException] { - Master.toAkkaUrl("spark://1.2. 3.4:1234") + Master.toAkkaUrl("spark://1.2. 
3.4:1234", "akka.tcp") } assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) } test("toAkkaAddress") { - val address = Master.toAkkaAddress("spark://1.2.3.4:1234") + val conf = new SparkConf(loadDefaults = false) + val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.tcp") assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address) } + test("toAkkaAddress with SSL") { + val conf = new SparkConf(loadDefaults = false) + val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.ssl.tcp") + assert(Address("akka.ssl.tcp", "sparkMaster", "1.2.3.4", 1234) === address) + } + test("toAkkaAddress: a typo url") { + val conf = new SparkConf(loadDefaults = false) val e = intercept[SparkException] { - Master.toAkkaAddress("spark://1.2. 3.4:1234") + Master.toAkkaAddress("spark://1.2. 3.4:1234", "akka.tcp") } assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala new file mode 100644 index 0000000000000..2fa90e3bd1c63 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.DataOutputStream +import java.net.{HttpURLConnection, URL} +import javax.servlet.http.HttpServletResponse + +import scala.collection.mutable + +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import com.google.common.base.Charsets +import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark._ +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} +import org.apache.spark.deploy.master.DriverState._ + +/** + * Tests for the REST application submission protocol used in standalone cluster mode. + */ +class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { + private val client = new StandaloneRestClient + private var actorSystem: Option[ActorSystem] = None + private var server: Option[StandaloneRestServer] = None + + override def afterEach() { + actorSystem.foreach(_.shutdown()) + server.foreach(_.stop()) + } + + test("construct submit request") { + val appArgs = Array("one", "two", "three") + val sparkProperties = Map("spark.app.name" -> "pi") + val environmentVariables = Map("SPARK_ONE" -> "UN", "SPARK_TWO" -> "DEUX") + val request = client.constructSubmitRequest( + "my-app-resource", "my-main-class", appArgs, sparkProperties, environmentVariables) + assert(request.action === Utils.getFormattedClassName(request)) + assert(request.clientSparkVersion === SPARK_VERSION) + assert(request.appResource === "my-app-resource") + assert(request.mainClass === "my-main-class") + assert(request.appArgs === appArgs) + assert(request.sparkProperties === sparkProperties) + assert(request.environmentVariables === environmentVariables) + } + + test("create submission") { + val submittedDriverId 
= "my-driver-id" + val submitMessage = "your driver is submitted" + val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) + val appArgs = Array("one", "two", "four") + val request = constructSubmitRequest(masterUrl, appArgs) + assert(request.appArgs === appArgs) + assert(request.sparkProperties("spark.master") === masterUrl) + val response = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response) + assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) + assert(submitResponse.serverSparkVersion === SPARK_VERSION) + assert(submitResponse.message === submitMessage) + assert(submitResponse.submissionId === submittedDriverId) + assert(submitResponse.success) + } + + test("create submission from main method") { + val submittedDriverId = "your-driver-id" + val submitMessage = "my driver is submitted" + val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.master", masterUrl) + conf.set("spark.app.name", "dreamer") + val appArgs = Array("one", "two", "six") + // main method calls this + val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf) + val submitResponse = getSubmitResponse(response) + assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) + assert(submitResponse.serverSparkVersion === SPARK_VERSION) + assert(submitResponse.message === submitMessage) + assert(submitResponse.submissionId === submittedDriverId) + assert(submitResponse.success) + } + + test("kill submission") { + val submissionId = "my-lyft-driver" + val killMessage = "your driver is killed" + val masterUrl = startDummyServer(killMessage = killMessage) + val response = client.killSubmission(masterUrl, submissionId) + val killResponse = getKillResponse(response) + assert(killResponse.action === Utils.getFormattedClassName(killResponse)) + 
assert(killResponse.serverSparkVersion === SPARK_VERSION) + assert(killResponse.message === killMessage) + assert(killResponse.submissionId === submissionId) + assert(killResponse.success) + } + + test("request submission status") { + val submissionId = "my-uber-driver" + val submissionState = KILLED + val submissionException = new Exception("there was an irresponsible mix of alcohol and cars") + val masterUrl = startDummyServer(state = submissionState, exception = Some(submissionException)) + val response = client.requestSubmissionStatus(masterUrl, submissionId) + val statusResponse = getStatusResponse(response) + assert(statusResponse.action === Utils.getFormattedClassName(statusResponse)) + assert(statusResponse.serverSparkVersion === SPARK_VERSION) + assert(statusResponse.message.contains(submissionException.getMessage)) + assert(statusResponse.submissionId === submissionId) + assert(statusResponse.driverState === submissionState.toString) + assert(statusResponse.success) + } + + test("create then kill") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response1) + assert(submitResponse.success) + assert(submitResponse.submissionId != null) + // kill submission that was just created + val submissionId = submitResponse.submissionId + val response2 = client.killSubmission(masterUrl, submissionId) + val killResponse = getKillResponse(response2) + assert(killResponse.success) + assert(killResponse.submissionId === submissionId) + } + + test("create then request status") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response1) + assert(submitResponse.success) + assert(submitResponse.submissionId != null) + // request status of submission that was just created + val submissionId = 
submitResponse.submissionId + val response2 = client.requestSubmissionStatus(masterUrl, submissionId) + val statusResponse = getStatusResponse(response2) + assert(statusResponse.success) + assert(statusResponse.submissionId === submissionId) + assert(statusResponse.driverState === RUNNING.toString) + } + + test("create then kill then request status") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val response2 = client.createSubmission(masterUrl, request) + val submitResponse1 = getSubmitResponse(response1) + val submitResponse2 = getSubmitResponse(response2) + assert(submitResponse1.success) + assert(submitResponse2.success) + assert(submitResponse1.submissionId != null) + assert(submitResponse2.submissionId != null) + val submissionId1 = submitResponse1.submissionId + val submissionId2 = submitResponse2.submissionId + // kill only submission 1, but not submission 2 + val response3 = client.killSubmission(masterUrl, submissionId1) + val killResponse = getKillResponse(response3) + assert(killResponse.success) + assert(killResponse.submissionId === submissionId1) + // request status for both submissions: 1 should be KILLED but 2 should be RUNNING still + val response4 = client.requestSubmissionStatus(masterUrl, submissionId1) + val response5 = client.requestSubmissionStatus(masterUrl, submissionId2) + val statusResponse1 = getStatusResponse(response4) + val statusResponse2 = getStatusResponse(response5) + assert(statusResponse1.submissionId === submissionId1) + assert(statusResponse2.submissionId === submissionId2) + assert(statusResponse1.driverState === KILLED.toString) + assert(statusResponse2.driverState === RUNNING.toString) + } + + test("kill or request status before create") { + val masterUrl = startSmartServer() + val doesNotExist = "does-not-exist" + // kill a non-existent submission + val response1 = client.killSubmission(masterUrl, doesNotExist) + val 
killResponse = getKillResponse(response1) + assert(!killResponse.success) + assert(killResponse.submissionId === doesNotExist) + // request status for a non-existent submission + val response2 = client.requestSubmissionStatus(masterUrl, doesNotExist) + val statusResponse = getStatusResponse(response2) + assert(!statusResponse.success) + assert(statusResponse.submissionId === doesNotExist) + } + + /* ---------------------------------------- * + | Aberrant client / server behavior | + * ---------------------------------------- */ + + test("good request paths") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val json = constructSubmitRequest(masterUrl).toJson + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill" + val statusRequestPath = s"$httpUrl/$v/submissions/status" + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", json) + val (response2, code2) = sendHttpRequestWithResponse(s"$killRequestPath/anything", "POST") + val (response3, code3) = sendHttpRequestWithResponse(s"$killRequestPath/any/thing", "POST") + val (response4, code4) = sendHttpRequestWithResponse(s"$statusRequestPath/anything", "GET") + val (response5, code5) = sendHttpRequestWithResponse(s"$statusRequestPath/any/thing", "GET") + // these should all succeed and the responses should be of the correct types + getSubmitResponse(response1) + val killResponse1 = getKillResponse(response2) + val killResponse2 = getKillResponse(response3) + val statusResponse1 = getStatusResponse(response4) + val statusResponse2 = getStatusResponse(response5) + assert(killResponse1.submissionId === "anything") + assert(killResponse2.submissionId === "any") + assert(statusResponse1.submissionId === "anything") + assert(statusResponse2.submissionId === "any") + assert(code1 === HttpServletResponse.SC_OK) + assert(code2 === 
HttpServletResponse.SC_OK) + assert(code3 === HttpServletResponse.SC_OK) + assert(code4 === HttpServletResponse.SC_OK) + assert(code5 === HttpServletResponse.SC_OK) + } + + test("good request paths, bad requests") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill" + val statusRequestPath = s"$httpUrl/$v/submissions/status" + val goodJson = constructSubmitRequest(masterUrl).toJson + val badJson1 = goodJson.replaceAll("action", "fraction") // invalid JSON + val badJson2 = goodJson.substring(goodJson.size / 2) // malformed JSON + val notJson = "\"hello, world\"" + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST") // missing JSON + val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson1) + val (response3, code3) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson2) + val (response4, code4) = sendHttpRequestWithResponse(killRequestPath, "POST") // missing ID + val (response5, code5) = sendHttpRequestWithResponse(s"$killRequestPath/", "POST") + val (response6, code6) = sendHttpRequestWithResponse(statusRequestPath, "GET") // missing ID + val (response7, code7) = sendHttpRequestWithResponse(s"$statusRequestPath/", "GET") + val (response8, code8) = sendHttpRequestWithResponse(submitRequestPath, "POST", notJson) + // these should all fail as error responses + getErrorResponse(response1) + getErrorResponse(response2) + getErrorResponse(response3) + getErrorResponse(response4) + getErrorResponse(response5) + getErrorResponse(response6) + getErrorResponse(response7) + getErrorResponse(response8) + assert(code1 === HttpServletResponse.SC_BAD_REQUEST) + assert(code2 === HttpServletResponse.SC_BAD_REQUEST) + assert(code3 === HttpServletResponse.SC_BAD_REQUEST) + assert(code4 === 
HttpServletResponse.SC_BAD_REQUEST) + assert(code5 === HttpServletResponse.SC_BAD_REQUEST) + assert(code6 === HttpServletResponse.SC_BAD_REQUEST) + assert(code7 === HttpServletResponse.SC_BAD_REQUEST) + assert(code8 === HttpServletResponse.SC_BAD_REQUEST) + } + + test("bad request paths") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET") + val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET") + val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET") + val (response4, code4) = sendHttpRequestWithResponse(s"$httpUrl/$v/", "GET") + val (response5, code5) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions", "GET") + val (response6, code6) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/", "GET") + val (response7, code7) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/bad", "GET") + val (response8, code8) = sendHttpRequestWithResponse(s"$httpUrl/bad-version", "GET") + assert(code1 === HttpServletResponse.SC_BAD_REQUEST) + assert(code2 === HttpServletResponse.SC_BAD_REQUEST) + assert(code3 === HttpServletResponse.SC_BAD_REQUEST) + assert(code4 === HttpServletResponse.SC_BAD_REQUEST) + assert(code5 === HttpServletResponse.SC_BAD_REQUEST) + assert(code6 === HttpServletResponse.SC_BAD_REQUEST) + assert(code7 === HttpServletResponse.SC_BAD_REQUEST) + assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + // all responses should be error responses + val errorResponse1 = getErrorResponse(response1) + val errorResponse2 = getErrorResponse(response2) + val errorResponse3 = getErrorResponse(response3) + val errorResponse4 = getErrorResponse(response4) + val errorResponse5 = getErrorResponse(response5) + val errorResponse6 = getErrorResponse(response6) + val errorResponse7 = getErrorResponse(response7) + val errorResponse8 = 
getErrorResponse(response8) + // only the incompatible version response should have server protocol version set + assert(errorResponse1.highestProtocolVersion === null) + assert(errorResponse2.highestProtocolVersion === null) + assert(errorResponse3.highestProtocolVersion === null) + assert(errorResponse4.highestProtocolVersion === null) + assert(errorResponse5.highestProtocolVersion === null) + assert(errorResponse6.highestProtocolVersion === null) + assert(errorResponse7.highestProtocolVersion === null) + assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION) + } + + test("server returns unknown fields") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val oldJson = constructSubmitRequest(masterUrl).toJson + val oldFields = parse(oldJson).asInstanceOf[JObject].obj + val newFields = oldFields ++ Seq( + JField("tomato", JString("not-a-fruit")), + JField("potato", JString("not-po-tah-to")) + ) + val newJson = pretty(render(JObject(newFields))) + // send two requests, one with the unknown fields and the other without + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", oldJson) + val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", newJson) + val submitResponse1 = getSubmitResponse(response1) + val submitResponse2 = getSubmitResponse(response2) + assert(code1 === HttpServletResponse.SC_OK) + assert(code2 === HttpServletResponse.SC_OK) + // only the response to the modified request should have unknown fields set + assert(submitResponse1.unknownFields === null) + assert(submitResponse2.unknownFields === Array("tomato", "potato")) + } + + test("client handles faulty server") { + val masterUrl = startFaultyServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val 
submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill/anything" + val statusRequestPath = s"$httpUrl/$v/submissions/status/anything" + val json = constructSubmitRequest(masterUrl).toJson + // server returns malformed response unwittingly + // client should throw an appropriate exception to indicate server failure + val conn1 = sendHttpRequest(submitRequestPath, "POST", json) + intercept[SubmitRestProtocolException] { client.readResponse(conn1) } + // server attempts to send invalid response, but fails internally on validation + // client should receive an error response as server is able to recover + val conn2 = sendHttpRequest(killRequestPath, "POST") + val response2 = client.readResponse(conn2) + getErrorResponse(response2) + assert(conn2.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + // server explodes internally beyond recovery + // client should throw an appropriate exception to indicate server failure + val conn3 = sendHttpRequest(statusRequestPath, "GET") + intercept[SubmitRestProtocolException] { client.readResponse(conn3) } // empty response + assert(conn3.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + } + + /* --------------------- * + | Helper methods | + * --------------------- */ + + /** Start a dummy server that responds to requests using the specified parameters. */ + private def startDummyServer( + submitId: String = "fake-driver-id", + submitMessage: String = "driver is submitted", + killMessage: String = "driver is killed", + state: DriverState = FINISHED, + exception: Option[Exception] = None): String = { + startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception)) + } + + /** Start a smarter dummy server that keeps track of submitted driver states. */ + private def startSmartServer(): String = { + startServer(new SmarterMaster) + } + + /** Start a dummy server that is faulty in many ways... 
*/ + private def startFaultyServer(): String = { + startServer(new DummyMaster, faulty = true) + } + + /** + * Start a [[StandaloneRestServer]] that communicates with the given actor. + * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead. + * Return the master URL that corresponds to the address of this server. + */ + private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = { + val name = "test-standalone-rest-protocol" + val conf = new SparkConf + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager) + val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) + val _server = + if (faulty) { + new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + } else { + new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + } + val port = _server.start() + // set these to clean them up after every test + actorSystem = Some(_actorSystem) + server = Some(_server) + s"spark://$localhost:$port" + } + + /** Create a submit request with real parameters using Spark submit. */ + private def constructSubmitRequest( + masterUrl: String, + appArgs: Array[String] = Array.empty): CreateSubmissionRequest = { + val mainClass = "main-class-not-used" + val mainJar = "dummy-jar-not-used.jar" + val commandLineArgs = Array( + "--deploy-mode", "cluster", + "--master", masterUrl, + "--name", mainClass, + "--class", mainClass, + mainJar) ++ appArgs + val args = new SparkSubmitArguments(commandLineArgs) + val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) + client.constructSubmitRequest( + mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) + } + + /** Return the response as a submit response, or fail with error otherwise. 
*/ + private def getSubmitResponse(response: SubmitRestProtocolResponse): CreateSubmissionResponse = { + response match { + case s: CreateSubmissionResponse => s + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected submit response. Actual: ${r.toJson}") + } + } + + /** Return the response as a kill response, or fail with error otherwise. */ + private def getKillResponse(response: SubmitRestProtocolResponse): KillSubmissionResponse = { + response match { + case k: KillSubmissionResponse => k + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected kill response. Actual: ${r.toJson}") + } + } + + /** Return the response as a status response, or fail with error otherwise. */ + private def getStatusResponse(response: SubmitRestProtocolResponse): SubmissionStatusResponse = { + response match { + case s: SubmissionStatusResponse => s + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected status response. Actual: ${r.toJson}") + } + } + + /** Return the response as an error response, or fail if the response was not an error. */ + private def getErrorResponse(response: SubmitRestProtocolResponse): ErrorResponse = { + response match { + case e: ErrorResponse => e + case r => fail(s"Expected error response. Actual: ${r.toJson}") + } + } + + /** + * Send an HTTP request to the given URL using the method and the body specified. + * Return the connection object. + */ + private def sendHttpRequest( + url: String, + method: String, + body: String = ""): HttpURLConnection = { + val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod(method) + if (body.nonEmpty) { + conn.setDoOutput(true) + val out = new DataOutputStream(conn.getOutputStream) + out.write(body.getBytes(Charsets.UTF_8)) + out.close() + } + conn + } + + /** + * Send an HTTP request to the given URL using the method and the body specified. 
+ * Return a 2-tuple of the response message from the server and the response code. + */ + private def sendHttpRequestWithResponse( + url: String, + method: String, + body: String = ""): (SubmitRestProtocolResponse, Int) = { + val conn = sendHttpRequest(url, method, body) + (client.readResponse(conn), conn.getResponseCode) + } +} + +/** + * A mock standalone Master that responds with dummy messages. + * In all responses, the success parameter is always true. + */ +private class DummyMaster( + submitId: String = "fake-driver-id", + submitMessage: String = "submitted", + killMessage: String = "killed", + state: DriverState = FINISHED, + exception: Option[Exception] = None) + extends Actor { + + override def receive = { + case RequestSubmitDriver(driverDesc) => + sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage) + case RequestKillDriver(driverId) => + sender ! KillDriverResponse(driverId, success = true, killMessage) + case RequestDriverStatus(driverId) => + sender ! DriverStatusResponse(found = true, Some(state), None, None, exception) + } +} + +/** + * A mock standalone Master that keeps track of drivers that have been submitted. + * + * If a driver is submitted, its state is immediately set to RUNNING. + * If an existing driver is killed, its state is immediately set to KILLED. + * If an existing driver's status is requested, its state is returned in the response. + * Submits are always successful while kills and status requests are successful only + * if the driver was submitted in the past. + */ +private class SmarterMaster extends Actor { + private var counter: Int = 0 + private val submittedDrivers = new mutable.HashMap[String, DriverState] + + override def receive = { + case RequestSubmitDriver(driverDesc) => + val driverId = s"driver-$counter" + submittedDrivers(driverId) = RUNNING + counter += 1 + sender ! 
SubmitDriverResponse(success = true, Some(driverId), "submitted") + + case RequestKillDriver(driverId) => + val success = submittedDrivers.contains(driverId) + if (success) { + submittedDrivers(driverId) = KILLED + } + sender ! KillDriverResponse(driverId, success, "killed") + + case RequestDriverStatus(driverId) => + val found = submittedDrivers.contains(driverId) + val state = submittedDrivers.get(driverId) + sender ! DriverStatusResponse(found, state, None, None, None) + } +} + +/** + * A [[StandaloneRestServer]] that is faulty in many ways. + * + * When handling a submit request, the server returns a malformed JSON. + * When handling a kill request, the server returns an invalid JSON. + * When handling a status request, the server throws an internal exception. + * The purpose of this class is to test that client handles these cases gracefully. + */ +private class FaultyStandaloneRestServer( + host: String, + requestedPort: Int, + masterActor: ActorRef, + masterUrl: String, + masterConf: SparkConf) + extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) { + + protected override val contextToServlet = Map[String, StandaloneRestServlet]( + s"$baseContext/create/*" -> new MalformedSubmitServlet, + s"$baseContext/kill/*" -> new InvalidKillServlet, + s"$baseContext/status/*" -> new ExplodingStatusServlet, + "/*" -> new ErrorServlet + ) + + /** A faulty servlet that produces malformed responses. */ + class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) { + protected override def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val badJson = responseMessage.toJson.drop(10).dropRight(20) + responseServlet.getWriter.write(badJson) + } + } + + /** A faulty servlet that produces invalid responses. 
*/ + class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = super.handleKill(submissionId) + k.submissionId = null + k + } + } + + /** A faulty status servlet that explodes. */ + class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) { + private def explode: Int = 1 / 0 + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val s = super.handleStatus(submissionId) + s.workerId = explode.toString + s + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala new file mode 100644 index 0000000000000..1d64ec201e647 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.rest + +import java.lang.Boolean +import java.lang.Integer + +import org.json4s.jackson.JsonMethods._ +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf + +/** + * Tests for the REST application submission protocol. + */ +class SubmitRestProtocolSuite extends FunSuite { + + test("validate") { + val request = new DummyRequest + intercept[SubmitRestProtocolException] { request.validate() } // missing everything + request.clientSparkVersion = "1.2.3" + intercept[SubmitRestProtocolException] { request.validate() } // missing name and age + request.name = "something" + intercept[SubmitRestProtocolException] { request.validate() } // missing only age + request.age = 2 + intercept[SubmitRestProtocolException] { request.validate() } // age too low + request.age = 10 + request.validate() // everything is set properly + request.clientSparkVersion = null + intercept[SubmitRestProtocolException] { request.validate() } // missing only Spark version + request.clientSparkVersion = "1.2.3" + request.name = null + intercept[SubmitRestProtocolException] { request.validate() } // missing only name + request.message = "not-setting-name" + intercept[SubmitRestProtocolException] { request.validate() } // still missing name + } + + test("request to and from JSON") { + val request = new DummyRequest + intercept[SubmitRestProtocolException] { request.toJson } // implicit validation + request.clientSparkVersion = "1.2.3" + request.active = true + request.age = 25 + request.name = "jung" + val json = request.toJson + assertJsonEquals(json, dummyRequestJson) + val newRequest = SubmitRestProtocolMessage.fromJson(json, classOf[DummyRequest]) + assert(newRequest.clientSparkVersion === "1.2.3") + assert(newRequest.clientSparkVersion === "1.2.3") + assert(newRequest.active) + assert(newRequest.age === 25) + assert(newRequest.name === "jung") + assert(newRequest.message === null) + } + + test("response to and from JSON") { + val response = 
new DummyResponse + response.serverSparkVersion = "3.3.4" + response.success = true + val json = response.toJson + assertJsonEquals(json, dummyResponseJson) + val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse]) + assert(newResponse.serverSparkVersion === "3.3.4") + assert(newResponse.serverSparkVersion === "3.3.4") + assert(newResponse.success) + assert(newResponse.message === null) + } + + test("CreateSubmissionRequest") { + val message = new CreateSubmissionRequest + intercept[SubmitRestProtocolException] { message.validate() } + message.clientSparkVersion = "1.2.3" + message.appResource = "honey-walnut-cherry.jar" + message.mainClass = "org.apache.spark.examples.SparkPie" + val conf = new SparkConf(false) + conf.set("spark.app.name", "SparkPie") + message.sparkProperties = conf.getAll.toMap + message.validate() + // optional fields + conf.set("spark.jars", "mayonnaise.jar,ketchup.jar") + conf.set("spark.files", "fireball.png") + conf.set("spark.driver.memory", "512m") + conf.set("spark.driver.cores", "180") + conf.set("spark.driver.extraJavaOptions", " -Dslices=5 -Dcolor=mostly_red") + conf.set("spark.driver.extraClassPath", "food-coloring.jar") + conf.set("spark.driver.extraLibraryPath", "pickle.jar") + conf.set("spark.driver.supervise", "false") + conf.set("spark.executor.memory", "256m") + conf.set("spark.cores.max", "10000") + message.sparkProperties = conf.getAll.toMap + message.appArgs = Array("two slices", "a hint of cinnamon") + message.environmentVariables = Map("PATH" -> "/dev/null") + message.validate() + // bad fields + var badConf = conf.clone().set("spark.driver.cores", "one hundred feet") + message.sparkProperties = badConf.getAll.toMap + intercept[SubmitRestProtocolException] { message.validate() } + badConf = conf.clone().set("spark.driver.supervise", "nope, never") + message.sparkProperties = badConf.getAll.toMap + intercept[SubmitRestProtocolException] { message.validate() } + badConf = 
conf.clone().set("spark.cores.max", "two men") + message.sparkProperties = badConf.getAll.toMap + intercept[SubmitRestProtocolException] { message.validate() } + message.sparkProperties = conf.getAll.toMap + // test JSON + val json = message.toJson + assertJsonEquals(json, submitDriverRequestJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionRequest]) + assert(newMessage.clientSparkVersion === "1.2.3") + assert(newMessage.appResource === "honey-walnut-cherry.jar") + assert(newMessage.mainClass === "org.apache.spark.examples.SparkPie") + assert(newMessage.sparkProperties("spark.app.name") === "SparkPie") + assert(newMessage.sparkProperties("spark.jars") === "mayonnaise.jar,ketchup.jar") + assert(newMessage.sparkProperties("spark.files") === "fireball.png") + assert(newMessage.sparkProperties("spark.driver.memory") === "512m") + assert(newMessage.sparkProperties("spark.driver.cores") === "180") + assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === " -Dslices=5 -Dcolor=mostly_red") + assert(newMessage.sparkProperties("spark.driver.extraClassPath") === "food-coloring.jar") + assert(newMessage.sparkProperties("spark.driver.extraLibraryPath") === "pickle.jar") + assert(newMessage.sparkProperties("spark.driver.supervise") === "false") + assert(newMessage.sparkProperties("spark.executor.memory") === "256m") + assert(newMessage.sparkProperties("spark.cores.max") === "10000") + assert(newMessage.appArgs === message.appArgs) + assert(newMessage.sparkProperties === message.sparkProperties) + assert(newMessage.environmentVariables === message.environmentVariables) + } + + test("CreateSubmissionResponse") { + val message = new CreateSubmissionResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.submissionId = "driver_123" + message.success = true + message.validate() + // test JSON + val json = message.toJson + assertJsonEquals(json, 
submitDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.submissionId === "driver_123") + assert(newMessage.success) + } + + test("KillSubmissionResponse") { + val message = new KillSubmissionResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.submissionId = "driver_123" + message.success = true + message.validate() + // test JSON + val json = message.toJson + assertJsonEquals(json, killDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillSubmissionResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.submissionId === "driver_123") + assert(newMessage.success) + } + + test("SubmissionStatusResponse") { + val message = new SubmissionStatusResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.submissionId = "driver_123" + message.success = true + message.validate() + // optional fields + message.driverState = "RUNNING" + message.workerId = "worker_123" + message.workerHostPort = "1.2.3.4:7780" + // test JSON + val json = message.toJson + assertJsonEquals(json, driverStatusResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmissionStatusResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.submissionId === "driver_123") + assert(newMessage.driverState === "RUNNING") + assert(newMessage.success) + assert(newMessage.workerId === "worker_123") + assert(newMessage.workerHostPort === "1.2.3.4:7780") + } + + test("ErrorResponse") { + val message = new ErrorResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.message = "Field not found in submit request: X" + message.validate() + // test JSON + val json = 
message.toJson + assertJsonEquals(json, errorJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[ErrorResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.message === "Field not found in submit request: X") + } + + private val dummyRequestJson = + """ + |{ + | "action" : "DummyRequest", + | "active" : true, + | "age" : 25, + | "clientSparkVersion" : "1.2.3", + | "name" : "jung" + |} + """.stripMargin + + private val dummyResponseJson = + """ + |{ + | "action" : "DummyResponse", + | "serverSparkVersion" : "3.3.4", + | "success": true + |} + """.stripMargin + + private val submitDriverRequestJson = + """ + |{ + | "action" : "CreateSubmissionRequest", + | "appArgs" : [ "two slices", "a hint of cinnamon" ], + | "appResource" : "honey-walnut-cherry.jar", + | "clientSparkVersion" : "1.2.3", + | "environmentVariables" : { + | "PATH" : "/dev/null" + | }, + | "mainClass" : "org.apache.spark.examples.SparkPie", + | "sparkProperties" : { + | "spark.driver.extraLibraryPath" : "pickle.jar", + | "spark.jars" : "mayonnaise.jar,ketchup.jar", + | "spark.driver.supervise" : "false", + | "spark.app.name" : "SparkPie", + | "spark.cores.max" : "10000", + | "spark.driver.memory" : "512m", + | "spark.files" : "fireball.png", + | "spark.driver.cores" : "180", + | "spark.driver.extraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red", + | "spark.executor.memory" : "256m", + | "spark.driver.extraClassPath" : "food-coloring.jar" + | } + |} + """.stripMargin + + private val submitDriverResponseJson = + """ + |{ + | "action" : "CreateSubmissionResponse", + | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", + | "success" : true + |} + """.stripMargin + + private val killDriverResponseJson = + """ + |{ + | "action" : "KillSubmissionResponse", + | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", + | "success" : true + |} + """.stripMargin + + private val driverStatusResponseJson = + """ + |{ + | "action" : 
"SubmissionStatusResponse", + | "driverState" : "RUNNING", + | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", + | "success" : true, + | "workerHostPort" : "1.2.3.4:7780", + | "workerId" : "worker_123" + |} + """.stripMargin + + private val errorJson = + """ + |{ + | "action" : "ErrorResponse", + | "message" : "Field not found in submit request: X", + | "serverSparkVersion" : "1.2.3" + |} + """.stripMargin + + /** Assert that the contents in the two JSON strings are equal after ignoring whitespace. */ + private def assertJsonEquals(jsonString1: String, jsonString2: String): Unit = { + val trimmedJson1 = jsonString1.trim + val trimmedJson2 = jsonString2.trim + val json1 = compact(render(parse(trimmedJson1))) + val json2 = compact(render(parse(trimmedJson2))) + // Put this on a separate line to avoid printing comparison twice when test fails + val equals = json1 == json2 + assert(equals, "\"[%s]\" did not equal \"[%s]\"".format(trimmedJson1, trimmedJson2)) + } +} + +private class DummyResponse extends SubmitRestProtocolResponse +private class DummyRequest extends SubmitRestProtocolRequest { + var active: Boolean = null + var age: Integer = null + var name: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(name, "name") + assertFieldIsSet(age, "age") + assert(age > 5, "Not old enough!") + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index b6f4411e0587a..aa6e4874cecde 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -27,6 +27,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkConf import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.util.Clock class DriverRunnerTest extends FunSuite { private def 
createDriverRunner() = { @@ -129,7 +130,7 @@ class DriverRunnerTest extends FunSuite { .thenReturn(-1) // fail 3 .thenReturn(-1) // fail 4 .thenReturn(0) // success - when(clock.currentTimeMillis()) + when(clock.getTimeMillis()) .thenReturn(0).thenReturn(1000) // fail 1 (short) .thenReturn(1000).thenReturn(2000) // fail 2 (short) .thenReturn(2000).thenReturn(10000) // fail 3 (long) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 6f233d7cf97aa..76511699e5ac5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -32,7 +32,7 @@ class ExecutorRunnerTest extends FunSuite { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") - val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", + val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123, new File(sparkHome), new File("ooga"), "blah", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) val builder = CommandUtils.buildProcessBuilder(appDesc.command, 512, sparkHome, er.substituteVariables) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 1a28a9a187cd7..372d7aa453008 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -43,7 +43,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new 
MySparkConf() @@ -62,7 +62,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala new file mode 100644 index 0000000000000..84e2fd7ad936d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.worker + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.Command + +import org.scalatest.{Matchers, FunSuite} + +class WorkerSuite extends FunSuite with Matchers { + + def cmd(javaOpts: String*) = Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) + def conf(opts: (String, String)*) = new SparkConf(loadDefaults = false).setAll(opts) + + test("test isUseLocalNodeSSLConfig") { + Worker.isUseLocalNodeSSLConfig(cmd("-Dasdf=dfgh")) shouldBe false + Worker.isUseLocalNodeSSLConfig(cmd("-Dspark.ssl.useNodeLocalConf=true")) shouldBe true + Worker.isUseLocalNodeSSLConfig(cmd("-Dspark.ssl.useNodeLocalConf=false")) shouldBe false + Worker.isUseLocalNodeSSLConfig(cmd("-Dspark.ssl.useNodeLocalConf=")) shouldBe false + } + + test("test maybeUpdateSSLSettings") { + Worker.maybeUpdateSSLSettings( + cmd("-Dasdf=dfgh", "-Dspark.ssl.opt1=x"), + conf("spark.ssl.opt1" -> "y", "spark.ssl.opt2" -> "z")) + .javaOpts should contain theSameElementsInOrderAs Seq( + "-Dasdf=dfgh", "-Dspark.ssl.opt1=x") + + Worker.maybeUpdateSSLSettings( + cmd("-Dspark.ssl.useNodeLocalConf=false", "-Dspark.ssl.opt1=x"), + conf("spark.ssl.opt1" -> "y", "spark.ssl.opt2" -> "z")) + .javaOpts should contain theSameElementsInOrderAs Seq( + "-Dspark.ssl.useNodeLocalConf=false", "-Dspark.ssl.opt1=x") + + Worker.maybeUpdateSSLSettings( + cmd("-Dspark.ssl.useNodeLocalConf=true", "-Dspark.ssl.opt1=x"), + conf("spark.ssl.opt1" -> "y", "spark.ssl.opt2" -> "z")) + .javaOpts should contain theSameElementsAs Seq( + "-Dspark.ssl.useNodeLocalConf=true", "-Dspark.ssl.opt1=y", "-Dspark.ssl.opt2=z") + + } +} diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala new file mode 100644 index 0000000000000..326e203afe136 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.scalatest.FunSuite + +class TaskMetricsSuite extends FunSuite { + test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { + val taskMetrics = new TaskMetrics() + taskMetrics.updateShuffleReadMetrics() + assert(taskMetrics.shuffleReadMetrics.isEmpty) + } +} diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 98b0a16ce88ba..2e58c159a2ed8 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.io.Text -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -42,7 +42,15 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { private var factory: CompressionCodecFactory = _ override def beforeAll() { - sc = new 
SparkContext("local", "test") + // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which + // can cause Filesystem.get(Configuration) to return a cached instance created with a different + // configuration than the one passed to get() (see HADOOP-8490 for more details). This caused + // hard-to-reproduce test failures, since any suites that were run after this one would inherit + // the new value of "fs.local.block.size" (see SPARK-5227 and SPARK-5679). To work around this, + // we disable FileSystem caching in this suite. + val conf = new SparkConf().set("spark.hadoop.fs.file.impl.disable.cache", "true") + + sc = new SparkContext("local", "test", conf) // Set the block size of local file system to test whether files are split right or not. sc.hadoopConfiguration.setLong("fs.local.block.size", 32) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 10a39990f80ce..78fa98a3b9065 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -21,35 +21,46 @@ import java.io.{File, FileWriter, PrintWriter} import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.commons.lang.math.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, + CombineFileRecordReader => OldCombineFileRecordReader, CombineFileSplit => OldCombineFileSplit} +import org.apache.hadoop.mapred.{JobConf, Reporter, FileSplit => OldFileSplit, + InputSplit => OldInputSplit, LineRecordReader => OldLineRecordReader, + RecordReader => 
OldRecordReader, TextInputFormat => OldTextInputFormat} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, + CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, + FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.apache.hadoop.mapreduce.{TaskAttemptContext, InputSplit => NewInputSplit, + RecordReader => NewRecordReader} +import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.SharedSparkContext import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.util.Utils -class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { +class InputOutputMetricsSuite extends FunSuite with SharedSparkContext + with BeforeAndAfter { @transient var tmpDir: File = _ @transient var tmpFile: File = _ @transient var tmpFilePath: String = _ + @transient val numRecords: Int = 100000 + @transient val numBuckets: Int = 10 - override def beforeAll() { - super.beforeAll() - + before { tmpDir = Utils.createTempDir() val testTempDir = new File(tmpDir, "test") testTempDir.mkdir() tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) - for (x <- 1 to 1000000) { - pw.println("s") + for (x <- 1 to numRecords) { + pw.println(RandomUtils.nextInt(numBuckets)) } pw.close() @@ -57,8 +68,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { tmpFilePath = "file://" + tmpFile.getAbsolutePath } - override def afterAll() { - super.afterAll() + after { Utils.deleteRecursively(tmpDir) } @@ -146,6 +156,101 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { assert(bytesRead >= tmpFile.length()) } + test("input metrics on records read - simple") { + val records = runAndReturnRecordsRead 
{ + sc.textFile(tmpFilePath, 4).count() + } + assert(records == numRecords) + } + + test("input metrics on records read - more stages") { + val records = runAndReturnRecordsRead { + sc.textFile(tmpFilePath, 4) + .map(key => (key.length, 1)) + .reduceByKey(_ + _) + .count() + } + assert(records == numRecords) + } + + test("input metrics on records - New Hadoop API") { + val records = runAndReturnRecordsRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).count() + } + assert(records == numRecords) + } + + test("input metrics on recordsd read with cache") { + // prime the cache manager + val rdd = sc.textFile(tmpFilePath, 4).cache() + rdd.collect() + + val records = runAndReturnRecordsRead { + rdd.count() + } + + assert(records == numRecords) + } + + test("shuffle records read metrics") { + val recordsRead = runAndReturnShuffleRecordsRead { + sc.textFile(tmpFilePath, 4) + .map(key => (key, 1)) + .groupByKey() + .collect() + } + assert(recordsRead == numRecords) + } + + test("shuffle records written metrics") { + val recordsWritten = runAndReturnShuffleRecordsWritten { + sc.textFile(tmpFilePath, 4) + .map(key => (key, 1)) + .groupByKey() + .collect() + } + assert(recordsWritten == numRecords) + } + + /** + * Tests the metrics from end to end. + * 1) reading a hadoop file + * 2) shuffling the records. + * 3) writing to a hadoop file.
+ */ + test("input read/write and shuffle read/write metrics all line up") { + var inputRead = 0L + var outputWritten = 0L + var shuffleRead = 0L + var shuffleWritten = 0L + sc.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val metrics = taskEnd.taskMetrics + metrics.inputMetrics.foreach(inputRead += _.recordsRead) + metrics.outputMetrics.foreach(outputWritten += _.recordsWritten) + metrics.shuffleReadMetrics.foreach(shuffleRead += _.recordsRead) + metrics.shuffleWriteMetrics.foreach(shuffleWritten += _.shuffleRecordsWritten) + } + }) + + val tmpFile = new File(tmpDir, getClass.getSimpleName) + + sc.textFile(tmpFilePath, 4) + .map(key => (key, 1)) + .reduceByKey(_+_) + .saveAsTextFile("file://" + tmpFile.getAbsolutePath) + + sc.listenerBus.waitUntilEmpty(500) + assert(inputRead == numRecords) + + // Only supported on newer Hadoop + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { + assert(outputWritten == numBuckets) + } + assert(shuffleRead == shuffleWritten) + } + test("input metrics with interleaved reads") { val numPartitions = 2 val cartVector = 0 to 9 @@ -184,25 +289,73 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { assert(cartesianBytes == firstSize * numPartitions + (cartVector.length * secondSize)) } - private def runAndReturnBytesRead(job : => Unit): Long = { - val taskBytesRead = new ArrayBuffer[Long]() + private def runAndReturnBytesRead(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.bytesRead)) + } + + private def runAndReturnRecordsRead(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.recordsRead)) + } + + private def runAndReturnRecordsWritten(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten)) + } + + private def runAndReturnShuffleRecordsRead(job: => Unit): Long = { + runAndReturnMetrics(job, 
_.taskMetrics.shuffleReadMetrics.map(_.recordsRead)) + } + + private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten)) + } + + private def runAndReturnMetrics(job: => Unit, + collector: (SparkListenerTaskEnd) => Option[Long]): Long = { + val taskMetrics = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead + collector(taskEnd).foreach(taskMetrics += _) } }) job sc.listenerBus.waitUntilEmpty(500) - taskBytesRead.sum + taskMetrics.sum + } + + test("output metrics on records written") { + // Only supported on newer Hadoop + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { + val file = new File(tmpDir, getClass.getSimpleName) + val filePath = "file://" + file.getAbsolutePath + + val records = runAndReturnRecordsWritten { + sc.parallelize(1 to numRecords).saveAsTextFile(filePath) + } + assert(records == numRecords) + } + } + + test("output metrics on records written - new Hadoop API") { + // Only supported on newer Hadoop + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { + val file = new File(tmpDir, getClass.getSimpleName) + val filePath = "file://" + file.getAbsolutePath + + val records = runAndReturnRecordsWritten { + sc.parallelize(1 to numRecords).map(key => (key.toString, key.toString)) + .saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](filePath) + } + assert(records == numRecords) + } } test("output metrics when writing text file") { val fs = FileSystem.getLocal(new Configuration()) val outPath = new Path(fs.getWorkingDirectory, "outdir") - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(outPath, fs.getConf).isDefined) { + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { val taskBytesWritten = new ArrayBuffer[Long]() 
sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { @@ -225,4 +378,88 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { } } } + + test("input metrics with old CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.hadoopFile(tmpFilePath, classOf[OldCombineTextInputFormat], classOf[LongWritable], + classOf[Text], 2).count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with new CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewCombineTextInputFormat], classOf[LongWritable], + classOf[Text], new Configuration()).count() + } + assert(bytesRead >= tmpFile.length()) + } +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class OldCombineTextInputFormat extends OldCombineFileInputFormat[LongWritable, Text] { + override def getRecordReader(split: OldInputSplit, conf: JobConf, reporter: Reporter) + : OldRecordReader[LongWritable, Text] = { + new OldCombineFileRecordReader[LongWritable, Text](conf, + split.asInstanceOf[OldCombineFileSplit], reporter, classOf[OldCombineTextRecordReaderWrapper] + .asInstanceOf[Class[OldRecordReader[LongWritable, Text]]]) + } +} + +class OldCombineTextRecordReaderWrapper( + split: OldCombineFileSplit, + conf: Configuration, + reporter: Reporter, + idx: Integer) extends OldRecordReader[LongWritable, Text] { + + val fileSplit = new OldFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate: OldLineRecordReader = new OldTextInputFormat().getRecordReader(fileSplit, + conf.asInstanceOf[JobConf], reporter).asInstanceOf[OldLineRecordReader] + + override def next(key: LongWritable, value: Text): Boolean = delegate.next(key, value) + override def createKey(): LongWritable = delegate.createKey() + override def createValue(): Text = delegate.createValue() + override 
def getPos(): Long = delegate.getPos + override def close(): Unit = delegate.close() + override def getProgress(): Float = delegate.getProgress +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] { + def createRecordReader(split: NewInputSplit, context: TaskAttemptContext) + : NewRecordReader[LongWritable, Text] = { + new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit], + context, classOf[NewCombineTextRecordReaderWrapper]) + } +} + +class NewCombineTextRecordReaderWrapper( + split: NewCombineFileSplit, + context: TaskAttemptContext, + idx: Integer) extends NewRecordReader[LongWritable, Text] { + + val fileSplit = new NewFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate = new NewTextInputFormat().createRecordReader(fileSplit, context) + + override def initialize(split: NewInputSplit, context: TaskAttemptContext): Unit = { + delegate.initialize(fileSplit, context) + } + + override def nextKeyValue(): Boolean = delegate.nextKeyValue() + override def getCurrentKey(): LongWritable = delegate.getCurrentKey + override def getCurrentValue(): Text = delegate.getCurrentValue + override def getProgress(): Float = delegate.getProgress + override def close(): Unit = delegate.close() } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 1a9ce8c607dcd..37e528435aa5d 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -27,7 +27,7 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { } test("MetricsConfig with default properties") { - val conf = new MetricsConfig(Option("dummy-file")) + val conf = new 
MetricsConfig(None) conf.initialize() assert(conf.properties.size() === 4) diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index de306533752c1..4cd0f97368ca3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -33,6 +33,9 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { val expectedHistogramResults = Array(0) assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) + val emptyRDD: RDD[Double] = sc.emptyRDD + assert(emptyRDD.histogram(buckets) === expectedHistogramResults) + assert(emptyRDD.histogram(buckets, true) === expectedHistogramResults) } test("WorksWithOutOfRangeWithOneBucket") { diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 6138d0bbd57f6..0dc59888f7304 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -29,22 +29,42 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { Class.forName("org.apache.derby.jdbc.EmbeddedDriver") val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") try { - val create = conn.createStatement - create.execute(""" - CREATE TABLE FOO( - ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), - DATA INTEGER - )""") - create.close() - val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") - (1 to 100).foreach { i => - insert.setInt(1, i * 2) - insert.executeUpdate + + try { + val create = conn.createStatement + create.execute(""" + CREATE TABLE FOO( + ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), + DATA INTEGER + )""") + create.close() + val insert = 
conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") + (1 to 100).foreach { i => + insert.setInt(1, i * 2) + insert.executeUpdate + } + insert.close() + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => + // table exists } - insert.close() - } catch { - case e: SQLException if e.getSQLState == "X0Y32" => + + try { + val create = conn.createStatement + create.execute("CREATE TABLE BIGINT_TEST(ID BIGINT NOT NULL, DATA INTEGER)") + create.close() + val insert = conn.prepareStatement("INSERT INTO BIGINT_TEST VALUES(?,?)") + (1 to 100).foreach { i => + insert.setLong(1, 100000000000000000L + 4000000000000000L * i) + insert.setInt(2, i) + insert.executeUpdate + } + insert.close() + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => // table exists + } + } finally { conn.close() } @@ -62,6 +82,18 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { assert(rdd.count === 100) assert(rdd.reduce(_+_) === 10100) } + + test("large id overflow") { + sc = new SparkContext("local", "test") + val rdd = new JdbcRDD( + sc, + () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, + "SELECT DATA FROM BIGINT_TEST WHERE ? 
<= ID AND ID <= ?", + 1131544775L, 567279358897692673L, 20, + (r: ResultSet) => { r.getInt(1) } ).cache() + assert(rdd.count === 100) + assert(rdd.reduce(_+_) === 5050) + } after { try { diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 1b112f1a41ca9..cd193ae4f5238 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -76,6 +76,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(0).mkString(",") === (0 to 32).mkString(",")) assert(slices(1).mkString(",") === (33 to 66).mkString(",")) assert(slices(2).mkString(",") === (67 to 100).mkString(",")) + assert(slices(2).isInstanceOf[Range.Inclusive]) } test("empty data") { @@ -227,4 +228,28 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } + + test("inclusive ranges with Int.MaxValue and Int.MinValue") { + val data1 = 1 to Int.MaxValue + val slices1 = ParallelCollectionRDD.slice(data1, 3) + assert(slices1.size === 3) + assert(slices1.map(_.size).sum === Int.MaxValue) + assert(slices1(2).isInstanceOf[Range.Inclusive]) + val data2 = -2 to Int.MinValue by -1 + val slices2 = ParallelCollectionRDD.slice(data2, 3) + assert(slices2.size === 3) + assert(slices2.map(_.size).sum === Int.MaxValue) + assert(slices2(2).isInstanceOf[Range.Inclusive]) + } + + test("empty ranges with Int.MaxValue and Int.MinValue") { + val data1 = Int.MaxValue until Int.MaxValue + val slices1 = ParallelCollectionRDD.slice(data1, 5) + assert(slices1.size === 5) + for (i <- 0 until 5) assert(slices1(i).size === 0) + val data2 = Int.MinValue until Int.MinValue // empty range at the opposite Int bound + val slices2 = ParallelCollectionRDD.slice(data2, 5) + assert(slices2.size === 5) + for (i <- 0
until 5) assert(slices2(i).size === 0) + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 0deb9b18b8688..bede1ffb3e2d0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -52,6 +52,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) + assert(!nums.isEmpty()) assert(nums.max() === 4) assert(nums.min() === 1) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) @@ -156,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } + test("treeAggregate") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def seqOp = (c: Long, x: Int) => c + x + def combOp = (c1: Long, c2: Long) => c1 + c2 + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) + assert(sum === -1000L) + } + } + + test("treeReduce") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + for (depth <- 1 until 10) { + val sum = rdd.treeReduce(_ + _, depth) + assert(sum === -1000) + } + } + test("basic caching") { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(rdd.collect().toList === List(1, 2, 3, 4)) @@ -545,6 +564,14 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sortedTopK === nums.sorted(ord).take(5)) } + test("isEmpty") { + assert(sc.emptyRDD.isEmpty()) + assert(sc.parallelize(Seq[Int]()).isEmpty()) + assert(!sc.parallelize(Seq(1)).isEmpty()) + assert(sc.parallelize(Seq(1,2,3), 3).filter(_ < 0).isEmpty()) + assert(!sc.parallelize(Seq(1,2,3), 3).filter(_ > 1).isEmpty()) + } + test("sample preserves 
partitioner") { val partitioner = new HashPartitioner(2) val rdd = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(partitioner) @@ -918,4 +945,45 @@ class RDDSuite extends FunSuite with SharedSparkContext { mutableDependencies += dep } } + + test("nested RDDs are not supported (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator } + nestedRDD.count() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("actions cannot be performed inside of transformations (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + rdd.map(x => x * rdd2.count).collect() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("cannot run actions after SparkContext has been stopped (SPARK-5063)") { + val existingRDD = sc.parallelize(1 to 100) + sc.stop() + val thrown = intercept[IllegalStateException] { + existingRDD.count() + } + assert(thrown.getMessage.contains("shutdown")) + } + + test("cannot call methods on a stopped SparkContext (SPARK-5063)") { + sc.stop() + def assertFails(block: => Any): Unit = { + val thrown = intercept[IllegalStateException] { + block + } + assert(thrown.getMessage.contains("stopped")) + } + assertFails { sc.parallelize(1 to 100) } + assertFails { sc.textFile("/nonexistent-path") } + } + } diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index a40f2ffeffdf9..64b1c24c47168 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -119,5 +119,33 @@ class SortingSuite extends FunSuite with SharedSparkContext with Matchers with L partitions(1).last should be > partitions(2).head 
partitions(2).last should be > partitions(3).head } + + test("get a range of elements in a sorted RDD that is on one partition") { + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 10).sortByKey() + val range = sorted.filterByRange(20, 40).collect() + assert((20 to 40).toArray === range.map(_._1)) + } + + test("get a range of elements over multiple partitions in a descendingly sorted RDD") { + val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 10).sortByKey(false) + val range = sorted.filterByRange(200, 800).collect() + assert((800 to 200 by -1).toArray === range.map(_._1)) + } + + test("get a range of elements in an array not partitioned by a range partitioner") { + val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairs = sc.parallelize(pairArr,10) + val range = pairs.filterByRange(200, 800).collect() + assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) + } + + test("get a range of elements over multiple partitions but not taking up full partitions") { + val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 10).sortByKey(false) + val range = sorted.filterByRange(250, 850).collect() + assert((850 to 250 by -1).toArray === range.map(_._1)) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d30eb10bbe947..4bf7f9e647d55 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -19,9 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls +import scala.util.control.NonFatal -import akka.actor._ -import akka.testkit.{ImplicitSender, TestKit, TestActorRef} import 
org.scalatest.{BeforeAndAfter, FunSuiteLike} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -33,10 +32,16 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite import org.apache.spark.executor.TaskMetrics -class BuggyDAGEventProcessActor extends Actor { - val state = 0 - def receive = { - case _ => throw new SparkException("error") +class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) + extends DAGSchedulerEventProcessLoop(dagScheduler) { + + override def post(event: DAGSchedulerEvent): Unit = { + try { + // Forward event to `onReceive` directly to avoid processing event asynchronously. + onReceive(event) + } catch { + case NonFatal(e) => onError(e) + } } } @@ -65,8 +70,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike - with ImplicitSender with BeforeAndAfter with LocalSparkContext with Timeouts { +class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -113,7 +117,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F var mapOutputTracker: MapOutputTrackerMaster = null var scheduler: DAGScheduler = null - var dagEventProcessTestActor: TestActorRef[DAGSchedulerEventProcessActor] = null + var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null /** * Set of cache locations to return from our mock BlockManagerMaster. 
@@ -167,13 +171,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F runLocallyWithinThread(job) } } - dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor]( - Props(classOf[DAGSchedulerEventProcessActor], scheduler))(system) + dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } override def afterAll() { super.afterAll() - TestKit.shutdownActorSystem(system) } /** @@ -190,7 +192,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F * DAGScheduler event loop. */ private def runEvent(event: DAGSchedulerEvent) { - dagEventProcessTestActor.receive(event) + dagEventProcessLoopTester.post(event) } /** @@ -206,7 +208,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) } } } @@ -217,7 +219,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, - Map[Long, Any]((accumId, 1)), null, null)) + Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null)) } } } @@ -397,8 +399,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F runLocallyWithinThread(job) } } - dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor]( - Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system) + dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler) val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) // Because the job wasn't actually cancelled, we shouldn't have received a failure 
message. @@ -475,7 +476,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null, Map[Long, Any](), - null, + createFakeTaskInfo(), null)) assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(1)) @@ -486,7 +487,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), null, Map[Long, Any](), - null, + createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) @@ -506,14 +507,14 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === 
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -726,18 +727,6 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assert(sc.parallelize(1 to 10, 2).first() === 1) } - test("DAGSchedulerActorSupervisor closes the SparkContext when EventProcessActor crashes") { - val actorSystem = ActorSystem("test") - val supervisor = actorSystem.actorOf( - Props(classOf[DAGSchedulerActorSupervisor], scheduler), "dagSupervisor") - supervisor ! Props[BuggyDAGEventProcessActor] - val child = expectMsgType[ActorRef] - watch(child) - child ! "hi" - expectMsgPF(){ case Terminated(child) => () } - assert(scheduler.sc.dagScheduler === null) - } - test("accumulator not calculated for resubmitted result stage") { //just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) @@ -746,7 +735,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assert(Accumulators.originals(accum.id).value === 1) + + val accVal = Accumulators.originals(accum.id).get.get.value + + assert(accVal === 1) + assertDataStructuresEmpty } @@ -777,5 +770,14 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) } + + // Nothing in this test should break if the task info's fields are null, but + // OutputCommitCoordinator requires the task info itself to not be null. 
+ private def createFakeTaskInfo(): TaskInfo = { + val info = new TaskInfo(0, 0, 0, 0L, "", "", TaskLocality.ANY, false) + info.finishTime = 1 // to prevent spurious errors in JobProgressListener + info + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala new file mode 100644 index 0000000000000..3cc860caa1d9b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import java.io.File +import java.util.concurrent.TimeoutException + +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} + +import org.apache.spark._ +import org.apache.spark.rdd.{RDD, FakeOutputCommitter} +import org.apache.spark.util.Utils + +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +/** + * Unit tests for the output commit coordination functionality. + * + * The unit test makes both the original task and the speculated task + * attempt to commit, where committing is emulated by creating a + * directory. If both tasks create directories then the end result is + * a failure. + * + * Note that there are some aspects of this test that are less than ideal. + * In particular, the test mocks the speculation-dequeuing logic to always + * dequeue a task and consider it as speculated. Immediately after initially + * submitting the tasks and calling reviveOffers(), reviveOffers() is invoked + * again to pick up the speculated task. This may be hacking the original + * behavior in too much of an unrealistic fashion. + * + * Also, the validation is done by checking the number of files in a directory. + * Ideally, an accumulator would be used for this, where we could increment + * the accumulator in the output committer's commitTask() call. If the call to + * commitTask() was called twice erroneously then the test would ideally fail because + * the accumulator would be incremented twice. + * + * The problem with this test implementation is that when both a speculated task and + * its original counterpart complete, only one of the accumulator's increments is + * captured. 
This results in a paradox where if the OutputCommitCoordinator logic + * was not in SparkHadoopWriter, the tests would still pass because only one of the + * increments would be captured even though the commit in both tasks was executed + * erroneously. + */ +class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { + + var outputCommitCoordinator: OutputCommitCoordinator = null + var tempDir: File = null + var sc: SparkContext = null + + before { + tempDir = Utils.createTempDir() + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName(classOf[OutputCommitCoordinatorSuite].getSimpleName) + .set("spark.speculation", "true") + sc = new SparkContext(conf) { + override private[spark] def createSparkEnv( + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus): SparkEnv = { + outputCommitCoordinator = spy(new OutputCommitCoordinator(conf)) + // Use Mockito.spy() to maintain the default infrastructure everywhere else. + // This mocking allows us to control the coordinator responses in test cases. 
+ SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator)) + } + } + // Use Mockito.spy() to maintain the default infrastructure everywhere else + val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]) + + doAnswer(new Answer[Unit]() { + override def answer(invoke: InvocationOnMock): Unit = { + // Submit the tasks, then force the task scheduler to dequeue the + // speculated task + invoke.callRealMethod() + mockTaskScheduler.backend.reviveOffers() + } + }).when(mockTaskScheduler).submitTasks(Matchers.any()) + + doAnswer(new Answer[TaskSetManager]() { + override def answer(invoke: InvocationOnMock): TaskSetManager = { + val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet] + new TaskSetManager(mockTaskScheduler, taskSet, 4) { + var hasDequeuedSpeculatedTask = false + override def dequeueSpeculativeTask( + execId: String, + host: String, + locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = { + if (!hasDequeuedSpeculatedTask) { + hasDequeuedSpeculatedTask = true + Some(0, TaskLocality.PROCESS_LOCAL) + } else { + None + } + } + } + } + }).when(mockTaskScheduler).createTaskSetManager(Matchers.any(), Matchers.any()) + + sc.taskScheduler = mockTaskScheduler + val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, mockTaskScheduler) + sc.taskScheduler.setDAGScheduler(dagSchedulerWithMockTaskScheduler) + sc.dagScheduler = dagSchedulerWithMockTaskScheduler + } + + after { + sc.stop() + tempDir.delete() + outputCommitCoordinator = null + } + + test("Only one of two duplicate commit tasks should commit") { + val rdd = sc.parallelize(Seq(1), 1) + sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _, + 0 until rdd.partitions.size, allowLocal = false) + assert(tempDir.list().size === 1) + } + + test("If commit fails, if task is retried it should not be locked, and will succeed.") { + val rdd = sc.parallelize(Seq(1), 1) + sc.runJob(rdd, 
OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _, + 0 until rdd.partitions.size, allowLocal = false) + assert(tempDir.list().size === 1) + } + + test("Job should not complete if all commits are denied") { + // Create a mock OutputCommitCoordinator that denies all attempts to commit + doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( + Matchers.any(), Matchers.any(), Matchers.any()) + val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) + def resultHandler(x: Int, y: Unit): Unit = {} + val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, + OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully, + 0 until rdd.partitions.size, resultHandler, 0) + // It's an error if the job completes successfully even though no committer was authorized, + // so throw an exception if the job was allowed to complete. + intercept[TimeoutException] { + Await.result(futureAction, 5 seconds) + } + assert(tempDir.list().size === 0) + } +} + +/** + * Class with methods that can be passed to runJob to test commits with a mock committer. 
+ */ +private case class OutputCommitFunctions(tempDirPath: String) { + + // Mock output committer that simulates a successful commit (after commit is authorized) + private def successfulOutputCommitter = new FakeOutputCommitter { + override def commitTask(context: TaskAttemptContext): Unit = { + Utils.createDirectory(tempDirPath) + } + } + + // Mock output committer that simulates a failed commit (after commit is authorized) + private def failingOutputCommitter = new FakeOutputCommitter { + override def commitTask(taskAttemptContext: TaskAttemptContext) { + throw new RuntimeException + } + } + + def commitSuccessfully(iter: Iterator[Int]): Unit = { + val ctx = TaskContext.get() + runCommitWithProvidedCommitter(ctx, iter, successfulOutputCommitter) + } + + def failFirstCommitAttempt(iter: Iterator[Int]): Unit = { + val ctx = TaskContext.get() + runCommitWithProvidedCommitter(ctx, iter, + if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) + } + + private def runCommitWithProvidedCommitter( + ctx: TaskContext, + iter: Iterator[Int], + outputCommitter: OutputCommitter): Unit = { + def jobConf = new JobConf { + override def getOutputCommitter(): OutputCommitter = outputCommitter + } + val sparkHadoopWriter = new SparkHadoopWriter(jobConf) { + override def newTaskAttemptContext( + conf: JobConf, + attemptId: TaskAttemptID): TaskAttemptContext = { + mock(classOf[TaskAttemptContext]) + } + } + sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber) + sparkHadoopWriter.commit() + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 7e360cc6082ec..702c4cb3bdef9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -61,7 +61,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { try 
{ val replayer = new ReplayListenerBus() replayer.addListener(eventMonster) - replayer.replay(logData, SPARK_VERSION) + replayer.replay(logData, SPARK_VERSION, logFilePath.toString) } finally { logData.close() } @@ -120,7 +120,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { try { val replayer = new ReplayListenerBus() replayer.addListener(eventMonster) - replayer.replay(logData, version) + replayer.replay(logData, version, eventLog.getPath().toString) } finally { logData.close() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 24f41bf8cccda..3a41ee8d4ae0c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -20,23 +20,21 @@ package org.apache.spark.scheduler import java.util.concurrent.Semaphore import scala.collection.mutable +import scala.collection.JavaConversions._ -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} -import org.scalatest.Matchers +import org.scalatest.{FunSuite, Matchers} -import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} -class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter - with BeforeAndAfterAll with ResetSystemProperties { +class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers + with ResetSystemProperties { /** Length of time to wait while draining listener events. 
*/ val WAIT_TIMEOUT_MILLIS = 10000 - before { - sc = new SparkContext("local", "SparkListenerSuite") - } + val jobCompletionTime = 1421191296660L test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter @@ -44,7 +42,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.addListener(counter) // Listener bus hasn't started yet, so posting events should not increment counter - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) } + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -54,7 +52,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers // After listener bus has stopped, posting events should not increment counter bus.stop() - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) } + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 5) // Listener bus must not be started twice @@ -99,7 +97,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.addListener(blockingListener) bus.start() - bus.post(SparkListenerJobEnd(0, JobSucceeded)) + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() // Listener should be blocked after start @@ -125,6 +123,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("basic creation of StageInfo") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -146,6 +145,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("basic creation of StageInfo with shuffle") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new 
SaveStageAndTaskInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -183,6 +183,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("StageInfo with fewer tasks than partitions") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -199,6 +200,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("local metrics") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) @@ -265,6 +267,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("onTaskGettingResult() called when result fetched remotely") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -285,6 +288,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("onTaskGettingResult() not called when result sent directly") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -300,6 +304,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("onTaskEnd() should be called for all started tasks, even after job has been killed") { + sc = new SparkContext("local", "SparkListenerSuite") val WAIT_TIMEOUT_MILLIS = 10000 val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -345,7 +350,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.start() // Post events to all listeners, and wait until the queue is drained - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) } + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } 
assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) // The exception should be caught, and the event should be propagated to other listeners @@ -354,6 +359,17 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers assert(jobCounter2.count === 5) } + test("registering listeners via spark.extraListeners") { + val conf = new SparkConf().setMaster("local").setAppName("test") + .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," + + classOf[BasicJobCounter].getName) + sc = new SparkContext(conf) + sc.listenerBus.listeners.collect { case x: BasicJobCounter => x}.size should be (1) + sc.listenerBus.listeners.collect { + case x: ListenerThatAcceptsSparkConf => x + }.size should be (1) + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -361,14 +377,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers assert(m.sum / m.size.toDouble > 0.0, msg) } - /** - * A simple listener that counts the number of jobs observed. - */ - private class BasicJobCounter extends SparkListener { - var count = 0 - override def onJobEnd(job: SparkListenerJobEnd) = count += 1 - } - /** * A simple listener that saves all task infos and task metrics. */ @@ -421,3 +429,19 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } } + +// These classes can't be declared inside of the SparkListenerSuite class because we don't want +// their constructors to contain references to SparkListenerSuite: + +/** + * A simple listener that counts the number of jobs observed. 
+ */ +private class BasicJobCounter extends SparkListener { + var count = 0 + override def onJobEnd(job: SparkListenerJobEnd) = count += 1 +} + +private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { + var count = 0 + override def onJobEnd(job: SparkListenerJobEnd) = count += 1 +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 84b9b788237bf..12330d8f63c40 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.FakeClock +import org.apache.spark.util.ManualClock class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -164,7 +164,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Offer a host with NO_PREF as the constraint, @@ -213,7 +213,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execC", "host2")) val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "execB"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // An executor that is not NODE_LOCAL should be rejected. 
@@ -234,7 +234,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), Seq() // Last task has no locality prefs ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -263,7 +263,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host2", "exec3")), Seq() // Last task has no locality prefs ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).get.index === 0) @@ -283,7 +283,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host3")), Seq(TaskLocation("host2")) ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen @@ -314,13 +314,14 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("delay scheduling with failed hosts") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), + ("exec3", "host3")) val taskSet = FakeTask.createTaskSet(3, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), Seq(TaskLocation("host3")) ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen @@ -352,7 +353,7 @@ class TaskSetManagerSuite extends FunSuite with 
LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -369,7 +370,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted @@ -401,7 +402,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { ("exec1.1", "host1"), ("exec2", "host2")) // affinity to exec1 on host1 - which we will fail. 
val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, 4, clock) { @@ -485,7 +486,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host1", "execB")), Seq(TaskLocation("host2", "execC")), Seq()) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) @@ -521,7 +522,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val taskSet = FakeTask.createTaskSet(2, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host1", "execA"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) @@ -610,7 +611,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host2"), TaskLocation("host1")), Seq(), Seq(TaskLocation("host3", "execC"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 0) @@ -636,7 +637,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host2")), Seq(), Seq(TaskLocation("host3"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // node-local tasks are scheduled without delay @@ -649,6 +650,47 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } + test("SPARK-4939: 
node-local tasks should be scheduled right after process-local tasks finished") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(ExecutorCacheTaskLocation("host1", "execA")), + Seq(ExecutorCacheTaskLocation("host2", "execB"))) + val clock = new ManualClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + // process-local tasks are scheduled first + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 2) + assert(manager.resourceOffer("execB", "host2", NODE_LOCAL).get.index === 3) + // node-local tasks are scheduled without delay + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 0) + assert(manager.resourceOffer("execB", "host2", NODE_LOCAL).get.index === 1) + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) + assert(manager.resourceOffer("execB", "host2", NODE_LOCAL) == None) + } + + test("SPARK-4939: no-pref tasks should be scheduled after process-local tasks finished") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + val taskSet = FakeTask.createTaskSet(3, + Seq(), + Seq(ExecutorCacheTaskLocation("host1", "execA")), + Seq(ExecutorCacheTaskLocation("host2", "execB"))) + val clock = new ManualClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + // process-local tasks are scheduled first + assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 1) + assert(manager.resourceOffer("execB", "host2", PROCESS_LOCAL).get.index === 2) + // no-pref tasks are scheduled without delay + assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL) == None) + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) + assert(manager.resourceOffer("execA", "host1", 
NO_PREF).get.index === 0) + assert(manager.resourceOffer("execA", "host1", ANY) == None) + } + test("Ensure TaskSetManager is usable after addition of levels") { // Regression test for SPARK-2931 sc = new SparkContext("local", "test") @@ -656,7 +698,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val taskSet = FakeTask.createTaskSet(2, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host2", "execB.1"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(ANY))) @@ -690,7 +732,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(HostTaskLocation("host1")), Seq(HostTaskLocation("host2")), Seq(HDFSCacheTaskLocation("host3"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) sched.removeExecutor("execA") diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala index 78a30a40bf19a..afbaa9ade811f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -17,25 +17,60 @@ package org.apache.spark.scheduler.mesos -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} -import org.apache.spark.scheduler.{SparkListenerExecutorAdded, LiveListenerBus, - TaskDescription, WorkerOffer, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend} -import 
org.apache.mesos.SchedulerDriver -import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, _} -import org.apache.mesos.Protos.Value.Scalar -import org.easymock.{Capture, EasyMock} import java.nio.ByteBuffer -import java.util.Collections import java.util -import org.scalatest.mock.EasyMockSugar +import java.util.Collections import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar { +import org.apache.mesos.SchedulerDriver +import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.Scalar +import org.mockito.Mockito._ +import org.mockito.Matchers._ +import org.mockito.{ArgumentCaptor, Matchers} +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.executor.MesosExecutorBackend +import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, + TaskDescription, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.cluster.mesos.{MesosSchedulerBackend, MemoryUtils} + +class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with MockitoSugar { + + test("check spark-class location correctly") { + val conf = new SparkConf + conf.set("spark.mesos.executor.home" , "/mesos-home") + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + + val mesosSchedulerBackend 
= new MesosSchedulerBackend(taskScheduler, sc, "master") + + // uri is null. + val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id") + assert(executorInfo.getCommand.getValue === s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") + + // uri exists. + conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz") + val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id") + assert(executorInfo1.getCommand.getValue === s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") + } test("mesos resource offers result in launching tasks") { def createOffer(id: Int, mem: Int, cpu: Int) = { @@ -52,20 +87,19 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")).setHostname(s"host${id.toString}").build() } - val driver = EasyMock.createMock(classOf[SchedulerDriver]) - val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] - val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2))) - EasyMock.replay(listenerBus) + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) - val sc = EasyMock.createMock(classOf[SparkContext]) - EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() - EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() - EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() - EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() - EasyMock.expect(sc.listenerBus).andReturn(listenerBus) - EasyMock.replay(sc) + val sc = mock[SparkContext] + when(sc.executorMemory).thenReturn(100) + when(sc.getSparkHome()).thenReturn(Option("/path")) + when(sc.executorEnvs).thenReturn(new 
mutable.HashMap[String, String]) + when(sc.conf).thenReturn(new SparkConf) + when(sc.listenerBus).thenReturn(listenerBus) val minMem = MemoryUtils.calculateTotalMemory(sc).toInt val minCpu = 4 @@ -89,25 +123,29 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea 2 )) val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) - EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc))) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) + when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val capture = new Capture[util.Collection[TaskInfo]] - EasyMock.expect( + val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + when( driver.launchTasks( - EasyMock.eq(Collections.singleton(mesosOffers.get(0).getId)), - EasyMock.capture(capture), - EasyMock.anyObject(classOf[Filters]) + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) ) - ).andReturn(Status.valueOf(1)).once - EasyMock.expect(driver.declineOffer(mesosOffers.get(1).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.expect(driver.declineOffer(mesosOffers.get(2).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.replay(driver) + ).thenReturn(Status.valueOf(1)) + when(driver.declineOffer(mesosOffers.get(1).getId)).thenReturn(Status.valueOf(1)) + when(driver.declineOffer(mesosOffers.get(2).getId)).thenReturn(Status.valueOf(1)) backend.resourceOffers(driver, mesosOffers) - EasyMock.verify(driver) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + verify(driver, times(1)).declineOffer(mesosOffers.get(1).getId) + verify(driver, times(1)).declineOffer(mesosOffers.get(2).getId) 
assert(capture.getValue.size() == 1) val taskInfo = capture.getValue.iterator().next() assert(taskInfo.getName.equals("n1")) @@ -119,15 +157,13 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea // Unwanted resources offered on an existing node. Make sure they are declined val mesosOffers2 = new java.util.ArrayList[Offer] mesosOffers2.add(createOffer(1, minMem, minCpu)) - EasyMock.reset(taskScheduler) - EasyMock.reset(driver) - EasyMock.expect(taskScheduler.resourceOffers(EasyMock.anyObject(classOf[Seq[WorkerOffer]])).andReturn(Seq(Seq()))) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) - EasyMock.expect(driver.declineOffer(mesosOffers2.get(0).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.replay(driver) + reset(taskScheduler) + reset(driver) + when(taskScheduler.resourceOffers(any(classOf[Seq[WorkerOffer]]))).thenReturn(Seq(Seq())) + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + when(driver.declineOffer(mesosOffers2.get(0).getId)).thenReturn(Status.valueOf(1)) backend.resourceOffers(driver, mesosOffers2) - EasyMock.verify(driver) + verify(driver, times(1)).declineOffer(mesosOffers2.get(0).getId) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala new file mode 100644 index 0000000000000..86a42a7398e4d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.mesos + +import java.nio.ByteBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData + +class MesosTaskLaunchDataSuite extends FunSuite { + test("serialize and deserialize data must be same") { + val serializedTask = ByteBuffer.allocate(40) + (Range(100, 110).map(serializedTask.putInt(_))) + serializedTask.rewind + val attemptNumber = 100 + val byteString = MesosTaskLaunchData(serializedTask, attemptNumber).toByteString + serializedTask.rewind + val mesosTaskLaunchData = MesosTaskLaunchData.fromByteString(byteString) + assert(mesosTaskLaunchData.attemptNumber == attemptNumber) + assert(mesosTaskLaunchData.serializedTask.equals(serializedTask)) + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 855f1b6276089..054a4c64897a9 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -29,9 +29,9 @@ class KryoSerializerDistributedSuite extends FunSuite { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) - conf.set("spark.task.maxFailures", "1") + .set("spark.serializer", 
"org.apache.spark.serializer.KryoSerializer") + .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) + .set("spark.task.maxFailures", "1") val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala new file mode 100644 index 0000000000000..e62828c4fbac6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.io.{ObjectOutput, ObjectInput} + +import org.scalatest.{BeforeAndAfterEach, FunSuite} + + +class SerializationDebuggerSuite extends FunSuite with BeforeAndAfterEach { + + import SerializationDebugger.find + + override def beforeEach(): Unit = { + SerializationDebugger.enableDebugging = true + } + + test("primitives, strings, and nulls") { + assert(find(1) === List.empty) + assert(find(1L) === List.empty) + assert(find(1.toShort) === List.empty) + assert(find(1.0) === List.empty) + assert(find("1") === List.empty) + assert(find(null) === List.empty) + } + + test("primitive arrays") { + assert(find(Array[Int](1, 2)) === List.empty) + assert(find(Array[Long](1, 2)) === List.empty) + } + + test("non-primitive arrays") { + assert(find(Array("aa", "bb")) === List.empty) + assert(find(Array(new SerializableClass1)) === List.empty) + } + + test("serializable object") { + assert(find(new Foo(1, "b", 'c', 'd', null, null, null)) === List.empty) + } + + test("nested arrays") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + val foo2 = new Foo(1, "b", 'c', 'd', null, Array(foo1), null) + assert(find(new Foo(1, "b", 'c', 'd', null, Array(foo2), null)) === List.empty) + } + + test("nested objects") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + val foo2 = new Foo(1, "b", 'c', 'd', null, null, foo1) + assert(find(new Foo(1, "b", 'c', 'd', null, null, foo2)) === List.empty) + } + + test("cycles (should not loop forever)") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + foo1.g = foo1 + assert(find(new Foo(1, "b", 'c', 'd', null, null, foo1)) === List.empty) + } + + test("root object not serializable") { + val s = find(new NotSerializable) + assert(s.size === 1) + assert(s.head.contains("NotSerializable")) + } + + test("array containing not serializable element") { + val s = find(new SerializableArray(Array(new NotSerializable))) + assert(s.size === 5) + 
assert(s(0).contains("NotSerializable")) + assert(s(1).contains("element of array")) + assert(s(2).contains("array")) + assert(s(3).contains("arrayField")) + assert(s(4).contains("SerializableArray")) + } + + test("object containing not serializable field") { + val s = find(new SerializableClass2(new NotSerializable)) + assert(s.size === 3) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("objectField")) + assert(s(2).contains("SerializableClass2")) + } + + test("externalizable class writing out not serializable object") { + val s = find(new ExternalizableClass) + assert(s.size === 5) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("objectField")) + assert(s(2).contains("SerializableClass2")) + assert(s(3).contains("writeExternal")) + assert(s(4).contains("ExternalizableClass")) + } +} + + +class SerializableClass1 extends Serializable + + +class SerializableClass2(val objectField: Object) extends Serializable + + +class SerializableArray(val arrayField: Array[Object]) extends Serializable + + +class ExternalizableClass extends java.io.Externalizable { + override def writeExternal(out: ObjectOutput): Unit = { + out.writeInt(1) + out.writeObject(new SerializableClass2(new NotSerializable)) + } + + override def readExternal(in: ObjectInput): Unit = {} +} + + +class Foo( + a: Int, + b: String, + c: Char, + d: Byte, + e: Array[Int], + f: Array[Object], + var g: Foo) extends Serializable + + +class NotSerializable diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index bbc7e1357b90d..c21c92b63ad13 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -31,6 +31,8 @@ class BlockObjectWriterSuite extends FunSuite { new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) 
+ // Record metrics update on every write + assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write assert(writeMetrics.shuffleBytesWritten == 0) // After 32 writes, metrics should update @@ -39,6 +41,7 @@ class BlockObjectWriterSuite extends FunSuite { writer.write(Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) + assert(writeMetrics.shuffleRecordsWritten === 33) writer.commitAndClose() assert(file.length() == writeMetrics.shuffleBytesWritten) } @@ -51,6 +54,8 @@ class BlockObjectWriterSuite extends FunSuite { new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) + // Record metrics update on every write + assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write assert(writeMetrics.shuffleBytesWritten == 0) // After 32 writes, metrics should update @@ -59,7 +64,23 @@ class BlockObjectWriterSuite extends FunSuite { writer.write(Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) + assert(writeMetrics.shuffleRecordsWritten === 33) writer.revertPartialWritesAndClose() assert(writeMetrics.shuffleBytesWritten == 0) + assert(writeMetrics.shuffleRecordsWritten == 0) + } + + test("Reopening a closed block writer") { + val file = new File("somefile") + file.deleteOnExit() + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + + writer.open() + writer.close() + intercept[IllegalStateException] { + writer.open() + } } } diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index dae7bf0e336de..82a82e23eecf2 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.storage import 
java.io.File import org.apache.spark.util.Utils -import org.scalatest.FunSuite +import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.SparkConf @@ -28,7 +28,11 @@ import org.apache.spark.SparkConf /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. */ -class LocalDirsSuite extends FunSuite { +class LocalDirsSuite extends FunSuite with BeforeAndAfter { + + before { + Utils.clearLocalRootDirs() + } test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 @@ -49,7 +53,7 @@ class LocalDirsSuite extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } // spark.local.dir only contains invalid directories, but that's not a problem since diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index e85a436cdba17..6a972381faf14 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -32,12 +32,21 @@ import org.apache.spark.api.java.StorageLevels import org.apache.spark.shuffle.FetchFailedException /** - * Selenium tests for the Spark Web UI. These tests are not run by default - * because they're slow. + * Selenium tests for the Spark Web UI. */ -@DoNotDiscover -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { - implicit val webDriver: WebDriver = new HtmlUnitDriver +class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll { + + implicit var webDriver: WebDriver = _ + + override def beforeAll(): Unit = { + webDriver = new HtmlUnitDriver + } + + override def afterAll(): Unit = { + if (webDriver != null) { + webDriver.quit() + } + } /** * Create a test SparkContext with the SparkUI enabled. 
@@ -48,6 +57,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { .setMaster("local") .setAppName("test") .set("spark.ui.enabled", "true") + .set("spark.ui.port", "0") val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -93,7 +103,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") - find(id("active")).get.text should be("Active Stages (0)") + find(id("active")) should be(None) // Since we hide empty tables find(id("failed")).get.text should be("Failed Stages (1)") } @@ -105,7 +115,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") - find(id("active")).get.text should be("Active Stages (0)") + find(id("active")) should be(None) // Since we hide empty tables // The failure occurs before the stage becomes active, hence we should still show only one // failed stage, not two: find(id("failed")).get.text should be("Failed Stages (1)") @@ -167,13 +177,14 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { test("job progress bars should handle stage / task failures") { withSpark(newSparkContext()) { sc => - val data = sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity) + val data = sc.parallelize(Seq(1, 2, 3), 1).map(identity).groupBy(identity) val shuffleHandle = data.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle // Simulate fetch failures: val mappedData = data.map { x => val taskContext = TaskContext.get - if (taskContext.attemptNumber == 0) { // Cause this stage to fail on its first attempt. 
+ if (taskContext.taskAttemptId() == 1) { + // Cause the post-shuffle stage to fail on its first attempt with a single task failure val env = SparkEnv.get val bmAddress = env.blockManager.blockManagerId val shuffleId = shuffleHandle.shuffleId diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index f865d8ca04d1b..730a4b54f5aa1 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.util.Utils class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { + val jobSubmissionTime = 1421191042750L + val jobCompletionTime = 1421191296660L private def createStageStartEvent(stageId: Int) = { val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") @@ -46,12 +48,12 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val stageInfos = stageIds.map { stageId => new StageInfo(stageId, 0, stageId.toString, 0, null, "") } - SparkListenerJobStart(jobId, stageInfos) + SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos) } private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { val result = if (failed) JobFailed(new Exception("dummy failure")) else JobSucceeded - SparkListenerJobEnd(jobId, result) + SparkListenerJobEnd(jobId, jobCompletionTime, result) } private def runJob(listener: SparkListener, jobId: Int, shouldFail: Boolean = false) { @@ -86,6 +88,28 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc listener.completedStages.map(_.stageId).toSet should be (Set(50, 49, 48, 47, 46)) } + test("test clearing of stageIdToActiveJobs") { + val conf = new SparkConf() + conf.set("spark.ui.retainedStages", 5.toString) + val listener = new JobProgressListener(conf) + val jobId = 0 + val 
stageIds = 1 to 50 + // Start a job with 50 stages + listener.onJobStart(createJobStartEvent(jobId, stageIds)) + for (stageId <- stageIds) { + listener.onStageSubmitted(createStageStartEvent(stageId)) + } + listener.stageIdToActiveJobIds.size should be > 0 + + // Complete the stages and job + for (stageId <- stageIds) { + listener.onStageCompleted(createStageEndEvent(stageId, failed = false)) + } + listener.onJobEnd(createJobEndEvent(jobId, false)) + assertActiveJobsStateIsEmpty(listener) + listener.stageIdToActiveJobIds.size should be (0) + } + test("test LRU eviction of jobs") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) @@ -138,7 +162,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(listener.stageIdToData.size === 0) // finish this task, should get updated shuffleRead - shuffleReadMetrics.remoteBytesRead = 1000 + shuffleReadMetrics.incRemoteBytesRead(1000) taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 @@ -224,18 +248,19 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val shuffleWriteMetrics = new ShuffleWriteMetrics() taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) - shuffleReadMetrics.remoteBytesRead = base + 1 - shuffleReadMetrics.remoteBlocksFetched = base + 2 - shuffleWriteMetrics.shuffleBytesWritten = base + 3 - taskMetrics.executorRunTime = base + 4 - taskMetrics.diskBytesSpilled = base + 5 - taskMetrics.memoryBytesSpilled = base + 6 + shuffleReadMetrics.incRemoteBytesRead(base + 1) + shuffleReadMetrics.incLocalBytesRead(base + 9) + shuffleReadMetrics.incRemoteBlocksFetched(base + 2) + shuffleWriteMetrics.incShuffleBytesWritten(base + 3) + taskMetrics.setExecutorRunTime(base + 4) + taskMetrics.incDiskBytesSpilled(base + 5) + 
taskMetrics.incMemoryBytesSpilled(base + 6) val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) taskMetrics.setInputMetrics(Some(inputMetrics)) - inputMetrics.addBytesRead(base + 7) + inputMetrics.incBytesRead(base + 7) val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) taskMetrics.outputMetrics = Some(outputMetrics) - outputMetrics.bytesWritten = base + 8 + outputMetrics.setBytesWritten(base + 8) taskMetrics } @@ -258,8 +283,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc var stage0Data = listener.stageIdToData.get((0, 0)).get var stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadBytes == 102) - assert(stage1Data.shuffleReadBytes == 201) + assert(stage0Data.shuffleReadTotalBytes == 220) + assert(stage1Data.shuffleReadTotalBytes == 410) assert(stage0Data.shuffleWriteBytes == 106) assert(stage1Data.shuffleWriteBytes == 203) assert(stage0Data.executorRunTime == 108) @@ -288,8 +313,11 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc stage0Data = listener.stageIdToData.get((0, 0)).get stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadBytes == 402) - assert(stage1Data.shuffleReadBytes == 602) + // Task 1235 contributed (100+1)+(100+9) = 210 shuffle bytes, and task 1234 contributed + // (300+1)+(300+9) = 610 total shuffle bytes, so the total for the stage is 820. + assert(stage0Data.shuffleReadTotalBytes == 820) + // Task 1236 contributed 410 shuffle bytes, and task 1237 contributed 810 shuffle bytes. 
+ assert(stage1Data.shuffleReadTotalBytes == 1220) assert(stage0Data.shuffleWriteBytes == 406) assert(stage1Data.shuffleWriteBytes == 606) assert(stage0Data.executorRunTime == 408) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6bbf72e929dcb..6250d50fb7036 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.util +import java.util.concurrent.TimeoutException + import scala.concurrent.Await +import scala.util.{Failure, Try} import akka.actor._ @@ -26,6 +29,7 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId +import org.apache.spark.SSLSampleConfigs._ /** @@ -47,7 +51,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") val badconf = new SparkConf badconf.set("spark.authenticate", "true") @@ -60,7 +64,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro conf = conf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) @@ -74,7 +78,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with 
ResetSystemPro val conf = new SparkConf conf.set("spark.authenticate", "false") conf.set("spark.authenticate.secret", "bad") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, @@ -85,18 +89,18 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(badconf); + val securityManagerBad = new SecurityManager(badconf) val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) val timeout = AkkaUtils.lookupTimeout(conf) slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) @@ -124,7 +128,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, @@ -135,12 +139,12 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val 
masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") val goodconf = new SparkConf goodconf.set("spark.authenticate", "true") goodconf.set("spark.authenticate.secret", "good") - val securityManagerGood = new SecurityManager(goodconf); + val securityManagerGood = new SecurityManager(goodconf) assert(securityManagerGood.isAuthenticationEnabled() === true) @@ -148,7 +152,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro conf = goodconf, securityManager = securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) val timeout = AkkaUtils.lookupTimeout(conf) slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) @@ -175,7 +179,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, @@ -186,12 +190,12 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", 
"bad") - val securityManagerBad = new SecurityManager(badconf); + val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -199,7 +203,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) @@ -209,4 +213,174 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro slaveSystem.shutdown() } + test("remote fetch ssl on") { + val conf = sparkSSLConfig() + val securityManager = new SecurityManager(conf) + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === false) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + + val slaveConf = sparkSSLConfig() + val securityManagerBad = new SecurityManager(slaveConf) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = slaveConf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTrackerWorker(conf) + val selection = slaveSystem.actorSelection( + AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) + val timeout = AkkaUtils.lookupTimeout(conf) + 
slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + assert(securityManagerBad.isAuthenticationEnabled() === false) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, + MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security off + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + + test("remote fetch ssl on and security enabled") { + val conf = sparkSSLConfig() + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + + val slaveConf = sparkSSLConfig() + slaveConf.set("spark.authenticate", "true") + slaveConf.set("spark.authenticate.secret", "good") + val securityManagerBad = new SecurityManager(slaveConf) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = slaveConf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTrackerWorker(conf) + val selection = slaveSystem.actorSelection( + AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, 
"MapOutputTracker")) + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + assert(securityManagerBad.isAuthenticationEnabled() === true) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, + MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + + test("remote fetch ssl on and security enabled - bad credentials") { + val conf = sparkSSLConfig() + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + + val slaveConf = sparkSSLConfig() + slaveConf.set("spark.authenticate", "true") + slaveConf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(slaveConf) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = slaveConf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTrackerWorker(conf) + val selection = slaveSystem.actorSelection( + AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", 
"localhost", boundPort, "MapOutputTracker")) + val timeout = AkkaUtils.lookupTimeout(conf) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + + test("remote fetch ssl on - untrusted server") { + val conf = sparkSSLConfigUntrusted() + val securityManager = new SecurityManager(conf) + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === false) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + + val slaveConf = sparkSSLConfig() + val securityManagerBad = new SecurityManager(slaveConf) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = slaveConf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTrackerWorker(conf) + val selection = slaveSystem.actorSelection( + AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) + val timeout = AkkaUtils.lookupTimeout(conf) + val result = Try(Await.result(selection.resolveOne(timeout * 2), timeout)) + + result match { + case Failure(ex: ActorNotFound) => + case Failure(ex: TimeoutException) => + case r => fail(s"$r is neither Failure(ActorNotFound) nor Failure(TimeoutException)") + } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + } diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala new file mode 100644 index 0000000000000..1026cb2aa7cae --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala 
@@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.CountDownLatch + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts +import org.scalatest.FunSuite + +class EventLoopSuite extends FunSuite with Timeouts { + + test("EventLoop") { + val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int] + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + buffer += event + } + + override def onError(e: Throwable): Unit = {} + } + eventLoop.start() + (1 to 100).foreach(eventLoop.post) + eventually(timeout(5 seconds), interval(5 millis)) { + assert((1 to 100) === buffer.toSeq) + } + eventLoop.stop() + } + + test("EventLoop: start and stop") { + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = {} + + override def onError(e: Throwable): Unit = {} + } + assert(false === eventLoop.isActive) + eventLoop.start() + assert(true === eventLoop.isActive) + eventLoop.stop() + assert(false === eventLoop.isActive) + } + + 
test("EventLoop: onError") { + val e = new RuntimeException("Oops") + @volatile var receivedError: Throwable = null + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + throw e + } + + override def onError(e: Throwable): Unit = { + receivedError = e + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(e === receivedError) + } + eventLoop.stop() + } + + test("EventLoop: error thrown from onError should not crash the event thread") { + val e = new RuntimeException("Oops") + @volatile var receivedError: Throwable = null + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + throw e + } + + override def onError(e: Throwable): Unit = { + receivedError = e + throw new RuntimeException("Oops") + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(e === receivedError) + assert(eventLoop.isActive) + } + eventLoop.stop() + } + + test("EventLoop: calling stop multiple times should only call onStop once") { + var onStopTimes = 0 + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopTimes += 1 + } + } + + eventLoop.start() + + eventLoop.stop() + eventLoop.stop() + eventLoop.stop() + + assert(1 === onStopTimes) + } + + test("EventLoop: post event in multiple threads") { + @volatile var receivedEventsCount = 0 + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + receivedEventsCount += 1 + } + + override def onError(e: Throwable): Unit = { + } + + } + eventLoop.start() + + val threadNum = 5 + val eventsFromEachThread = 100 + (1 to threadNum).foreach { _ => + new Thread() { + override def run(): Unit = { + (1 to eventsFromEachThread).foreach(eventLoop.post) + } + }.start() + } + + 
eventually(timeout(5 seconds), interval(5 millis)) { + assert(threadNum * eventsFromEachThread === receivedEventsCount) + } + eventLoop.stop() + } + + test("EventLoop: onReceive swallows InterruptException") { + val onReceiveLatch = new CountDownLatch(1) + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + onReceiveLatch.countDown() + try { + Thread.sleep(5000) + } catch { + case ie: InterruptedException => // swallow + } + } + + override def onError(e: Throwable): Unit = { + } + + } + eventLoop.start() + eventLoop.post(1) + failAfter(5 seconds) { + // Wait until we enter `onReceive` + onReceiveLatch.await() + eventLoop.stop() + } + assert(false === eventLoop.isActive) + } + + test("EventLoop: stop in eventThread") { + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 71dfed1289850..a2be724254d7c 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -34,6 +34,12 @@ import org.apache.spark.storage._ class JsonProtocolSuite extends FunSuite { + val jobSubmissionTime = 1421191042750L + val jobCompletionTime = 1421191296660L + + val executorAddedTime = 1421458410000L + val executorRemovedTime = 1421458922000L + test("SparkListenerEvent") { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) @@ -54,9 +60,9 @@ class JsonProtocolSuite extends FunSuite { val stageIds = Seq[Int](1, 2, 3, 4) val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) - 
SparkListenerJobStart(10, stageInfos, properties) + SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) } - val jobEnd = SparkListenerJobEnd(20, JobSucceeded) + val jobEnd = SparkListenerJobEnd(20, jobCompletionTime, JobSucceeded) val environmentUpdate = SparkListenerEnvironmentUpdate(Map[String, Seq[(String, String)]]( "JVM Information" -> Seq(("GC speed", "9999 objects/s"), ("Java home", "Land of coffee")), "Spark Properties" -> Seq(("Job throughput", "80000 jobs/s, regardless of job type")), @@ -70,9 +76,10 @@ class JsonProtocolSuite extends FunSuite { val unpersistRdd = SparkListenerUnpersistRDD(12345) val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield") val applicationEnd = SparkListenerApplicationEnd(42L) - val executorAdded = SparkListenerExecutorAdded("exec1", - new ExecutorInfo("Hostee.awesome.com", 11)) - val executorRemoved = SparkListenerExecutorRemoved("exec2") + val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap + val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", + new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) + val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -94,13 +101,14 @@ class JsonProtocolSuite extends FunSuite { } test("Dependent Classes") { + val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics( 33333L, 44444L, 55555L, 66666L, 7, 8, hasHadoopInput = false, hasOutput = false)) testBlockManagerId(BlockManagerId("Hong", "Kong", 500)) - testExecutorInfo(new ExecutorInfo("host", 43)) + testExecutorInfo(new ExecutorInfo("host", 43, logUrlMap)) // StorageLevel 
testStorageLevel(StorageLevel.NONE) @@ -181,6 +189,34 @@ class JsonProtocolSuite extends FunSuite { assert(newMetrics.inputMetrics.isEmpty) } + test("Input/Output records backwards compatibility") { + // records read were added after 1.2 + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = true, hasOutput = true, hasRecords = false) + assert(metrics.inputMetrics.nonEmpty) + assert(metrics.outputMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Records Read" } + .removeField { case (field, _) => field == "Records Written" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.inputMetrics.get.recordsRead == 0) + assert(newMetrics.outputMetrics.get.recordsWritten == 0) + } + + test("Shuffle Read/Write records backwards compatibility") { + // records read were added after 1.2 + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = false, hasOutput = false, hasRecords = false) + assert(metrics.shuffleReadMetrics.nonEmpty) + assert(metrics.shuffleWriteMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Total Records Read" } + .removeField { case (field, _) => field == "Shuffle Records Written" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.shuffleReadMetrics.get.recordsRead == 0) + assert(newMetrics.shuffleWriteMetrics.get.shuffleRecordsWritten == 0) + } + test("OutputMetrics backward compatibility") { // OutputMetrics were added after 1.1 val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = false, hasOutput = true) @@ -224,6 +260,18 @@ class JsonProtocolSuite extends FunSuite { assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } + test("ShuffleReadMetrics: Local bytes read and time taken backwards compatibility") { + // Metrics about local shuffle 
bytes read and local read time were added in 1.3.1. + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = false, hasOutput = false, hasRecords = false) + assert(metrics.shuffleReadMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Local Bytes Read" } + .removeField { case (field, _) => field == "Local Read Time" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.shuffleReadMetrics.get.localBytesRead == 0) + } + test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") @@ -247,13 +295,31 @@ class JsonProtocolSuite extends FunSuite { val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) val dummyStageInfos = stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) - val jobStart = SparkListenerJobStart(10, stageInfos, properties) + val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) val oldEvent = JsonProtocol.jobStartToJson(jobStart).removeField({_._1 == "Stage Infos"}) val expectedJobStart = - SparkListenerJobStart(10, dummyStageInfos, properties) + SparkListenerJobStart(10, jobSubmissionTime, dummyStageInfos, properties) assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent)) } + test("SparkListenerJobStart and SparkListenerJobEnd backward compatibility") { + // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property. + // Also, SparkListenerJobEnd did not have a "Completion Time" property. 
+ val stageIds = Seq[Int](1, 2, 3, 4) + val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50)) + val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) + val oldStartEvent = JsonProtocol.jobStartToJson(jobStart) + .removeField({ _._1 == "Submission Time"}) + val expectedJobStart = SparkListenerJobStart(11, -1, stageInfos, properties) + assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldStartEvent)) + + val jobEnd = SparkListenerJobEnd(11, jobCompletionTime, JobSucceeded) + val oldEndEvent = JsonProtocol.jobEndToJson(jobEnd) + .removeField({ _._1 == "Completion Time"}) + val expectedJobEnd = SparkListenerJobEnd(11, -1, JobSucceeded) + assertEquals(expectedJobEnd, JsonProtocol.jobEndFromJson(oldEndEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -618,36 +684,42 @@ class JsonProtocolSuite extends FunSuite { e: Int, f: Int, hasHadoopInput: Boolean, - hasOutput: Boolean) = { + hasOutput: Boolean, + hasRecords: Boolean = true) = { val t = new TaskMetrics - t.hostname = "localhost" - t.executorDeserializeTime = a - t.executorRunTime = b - t.resultSize = c - t.jvmGCTime = d - t.resultSerializationTime = a + b - t.memoryBytesSpilled = a + c + t.setHostname("localhost") + t.setExecutorDeserializeTime(a) + t.setExecutorRunTime(b) + t.setResultSize(c) + t.setJvmGCTime(d) + t.setResultSerializationTime(a + b) + t.incMemoryBytesSpilled(a + c) if (hasHadoopInput) { val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - inputMetrics.addBytesRead(d + e + f) + inputMetrics.incBytesRead(d + e + f) + inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) t.setInputMetrics(Some(inputMetrics)) } else { val sr = new ShuffleReadMetrics - sr.remoteBytesRead = b + d - sr.localBlocksFetched = e - sr.fetchWaitTime = a + d - sr.remoteBlocksFetched = f + sr.incRemoteBytesRead(b + d) + sr.incLocalBlocksFetched(e) + 
sr.incFetchWaitTime(a + d) + sr.incRemoteBlocksFetched(f) + sr.incRecordsRead(if (hasRecords) (b + d) / 100 else -1) + sr.incLocalBytesRead(a + f) t.setShuffleReadMetrics(Some(sr)) } if (hasOutput) { val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) - outputMetrics.bytesWritten = a + b + c + outputMetrics.setBytesWritten(a + b + c) + outputMetrics.setRecordsWritten(if (hasRecords) (a + b + c)/100 else -1) t.outputMetrics = Some(outputMetrics) } else { val sw = new ShuffleWriteMetrics - sw.shuffleBytesWritten = a + b + c - sw.shuffleWriteTime = b + c + d + sw.incShuffleBytesWritten(a + b + c) + sw.incShuffleWriteTime(b + c + d) + sw.setShuffleRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) t.shuffleWriteMetrics = Some(sw) } // Make at most 6 blocks @@ -881,11 +953,14 @@ class JsonProtocolSuite extends FunSuite { | "Remote Blocks Fetched": 800, | "Local Blocks Fetched": 700, | "Fetch Wait Time": 900, - | "Remote Bytes Read": 1000 + | "Remote Bytes Read": 1000, + | "Local Bytes Read": 1100, + | "Total Records Read" : 10 | }, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, - | "Shuffle Write Time": 1500 + | "Shuffle Write Time": 1500, + | "Shuffle Records Written": 12 | }, | "Updated Blocks": [ | { @@ -962,11 +1037,13 @@ class JsonProtocolSuite extends FunSuite { | "Disk Bytes Spilled": 0, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, - | "Shuffle Write Time": 1500 + | "Shuffle Write Time": 1500, + | "Shuffle Records Written": 12 | }, | "Input Metrics": { | "Data Read Method": "Hadoop", - | "Bytes Read": 2100 + | "Bytes Read": 2100, + | "Records Read": 21 | }, | "Updated Blocks": [ | { @@ -1043,11 +1120,13 @@ class JsonProtocolSuite extends FunSuite { | "Disk Bytes Spilled": 0, | "Input Metrics": { | "Data Read Method": "Hadoop", - | "Bytes Read": 2100 + | "Bytes Read": 2100, + | "Records Read": 21 | }, | "Output Metrics": { | "Data Write Method": "Hadoop", - | "Bytes Written": 1200 + | "Bytes Written": 1200, + | 
"Records Written": 12 | }, | "Updated Blocks": [ | { @@ -1075,6 +1154,7 @@ class JsonProtocolSuite extends FunSuite { |{ | "Event": "SparkListenerJobStart", | "Job ID": 10, + | "Submission Time": 1421191042750, | "Stage Infos": [ | { | "Stage ID": 1, @@ -1349,6 +1429,7 @@ class JsonProtocolSuite extends FunSuite { |{ | "Event": "SparkListenerJobEnd", | "Job ID": 20, + | "Completion Time": 1421191296660, | "Job Result": { | "Result": "JobSucceeded" | } @@ -1430,22 +1511,29 @@ class JsonProtocolSuite extends FunSuite { """ private val executorAddedJsonString = - """ + s""" |{ | "Event": "SparkListenerExecutorAdded", + | "Timestamp": ${executorAddedTime}, | "Executor ID": "exec1", | "Executor Info": { | "Host": "Hostee.awesome.com", - | "Total Cores": 11 + | "Total Cores": 11, + | "Log Urls" : { + | "stderr" : "mystderr", + | "stdout" : "mystdout" + | } | } |} """ private val executorRemovedJsonString = - """ + s""" |{ | "Event": "SparkListenerExecutorRemoved", - | "Executor ID": "exec2" + | "Timestamp": ${executorRemovedTime}, + | "Executor ID": "exec2", + | "Removed Reason": "test reason" |} """ } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala similarity index 75% rename from core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala rename to core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index e2050e95a1b88..31e3b7e7bb71b 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.executor +package org.apache.spark.util import java.net.URLClassLoader @@ -24,32 +24,40 @@ import org.scalatest.FunSuite import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils} import org.apache.spark.util.Utils -class ExecutorURLClassLoaderSuite extends FunSuite { +class MutableURLClassLoaderSuite extends FunSuite { - val childClassNames = List("FakeClass1", "FakeClass2") - val parentClassNames = List("FakeClass1", "FakeClass2", "FakeClass3") - val urls = List(TestUtils.createJarWithClasses(childClassNames, "1")).toArray - val urls2 = List(TestUtils.createJarWithClasses(parentClassNames, "2")).toArray + val urls2 = List(TestUtils.createJarWithClasses( + classNames = Seq("FakeClass1", "FakeClass2", "FakeClass3"), + toStringValue = "2")).toArray + val urls = List(TestUtils.createJarWithClasses( + classNames = Seq("FakeClass1"), + classNamesWithBase = Seq(("FakeClass2", "FakeClass3")), // FakeClass3 is in parent + toStringValue = "1", + classpathUrls = urls2)).toArray test("child first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) val fakeClass = classLoader.loadClass("FakeClass2").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") + val fakeClass2 = classLoader.loadClass("FakeClass2").newInstance() + assert(fakeClass.getClass === fakeClass2.getClass) } test("parent first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorURLClassLoader(urls, parentLoader) + val classLoader = new MutableURLClassLoader(urls, parentLoader) val fakeClass = classLoader.loadClass("FakeClass1").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") + val fakeClass2 = classLoader.loadClass("FakeClass1").newInstance() + assert(fakeClass.getClass === fakeClass2.getClass) } 
test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) val fakeClass = classLoader.loadClass("FakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -57,7 +65,7 @@ class ExecutorURLClassLoaderSuite extends FunSuite { test("child first can fail") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("FakeClassDoesNotExist").newInstance() } diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index d4b92f33dd9e6..bad1aa99952cf 100644 --- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.Properties +import org.apache.commons.lang3.SerializationUtils import org.scalatest.{BeforeAndAfterEach, Suite} /** @@ -42,7 +43,11 @@ private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Su var oldProperties: Properties = null override def beforeEach(): Unit = { - oldProperties = new Properties(System.getProperties) + // we need SerializationUtils.clone instead of `new Properties(System.getProperties()` because + // the later way of creating a copy does not copy the properties but it initializes a new + // Properties object with the given properties as defaults. They are not recognized at all + // by standard Scala wrapper over Java Properties then. 
+ oldProperties = SerializationUtils.clone(System.getProperties) super.beforeEach() } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4544382094f96..fe2b644251157 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -29,6 +29,9 @@ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.scalatest.FunSuite +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -381,4 +384,32 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { require(cnt === 2, "prepare should be called twice") require(time < 500, "preparation time should not count") } + + test("fetch hcfs dir") { + val tempDir = Utils.createTempDir() + val innerTempDir = Utils.createTempDir(tempDir.getPath) + val tempFile = File.createTempFile("someprefix", "somesuffix", innerTempDir) + val targetDir = new File("target-dir") + Files.write("some text", tempFile, UTF_8) + + try { + val path = new Path("file://" + tempDir.getAbsolutePath) + val conf = new Configuration() + val fs = Utils.getHadoopFileSystem(path.toString, conf) + Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) + assert(targetDir.exists()) + assert(targetDir.isDirectory()) + val newInnerDir = new File(targetDir, innerTempDir.getName) + println("inner temp dir: " + innerTempDir.getName) + targetDir.listFiles().map(_.getName).foreach(println) + assert(newInnerDir.exists()) + assert(newInnerDir.isDirectory()) + val newInnerFile = new File(newInnerDir, tempFile.getName) + assert(newInnerFile.exists()) + assert(newInnerFile.isFile()) + } finally { + Utils.deleteRecursively(tempDir) + Utils.deleteRecursively(targetDir) + } + } } diff --git 
a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala index 4918e2d92beb4..daa795a043495 100644 --- a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala +++ b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala @@ -44,13 +44,21 @@ class ImplicitSuite { } def testRddToSequenceFileRDDFunctions(): Unit = { - // TODO eliminating `import intToIntWritable` needs refactoring SequenceFileRDDFunctions. - // That will be a breaking change. - import org.apache.spark.SparkContext.intToIntWritable val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD rdd.saveAsSequenceFile("/a/test/path") } + def testRddToSequenceFileRDDFunctionsWithWritable(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(org.apache.hadoop.io.IntWritable, org.apache.hadoop.io.Text)] + = mockRDD + rdd.saveAsSequenceFile("/a/test/path") + } + + def testRddToSequenceFileRDDFunctionsWithBytesArray(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(Int, Array[Byte])] = mockRDD + rdd.saveAsSequenceFile("/a/test/path") + } + def testRddToOrderedRDDFunctions(): Unit = { val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD rdd.sortByKey() diff --git a/data/mllib/als/sample_movielens_movies.txt b/data/mllib/als/sample_movielens_movies.txt new file mode 100644 index 0000000000000..934a0253849e1 --- /dev/null +++ b/data/mllib/als/sample_movielens_movies.txt @@ -0,0 +1,100 @@ +0::Movie 0::Romance|Comedy +1::Movie 1::Action|Anime +2::Movie 2::Romance|Thriller +3::Movie 3::Action|Romance +4::Movie 4::Anime|Comedy +5::Movie 5::Action|Action +6::Movie 6::Action|Comedy +7::Movie 7::Anime|Comedy +8::Movie 8::Comedy|Action +9::Movie 9::Anime|Thriller +10::Movie 10::Action|Anime +11::Movie 11::Action|Anime +12::Movie 12::Anime|Comedy +13::Movie 13::Thriller|Action +14::Movie 14::Anime|Comedy +15::Movie 15::Comedy|Thriller +16::Movie 16::Anime|Romance +17::Movie 17::Thriller|Action +18::Movie 18::Action|Comedy +19::Movie 19::Anime|Romance 
+20::Movie 20::Action|Anime +21::Movie 21::Romance|Thriller +22::Movie 22::Romance|Romance +23::Movie 23::Comedy|Comedy +24::Movie 24::Anime|Action +25::Movie 25::Comedy|Comedy +26::Movie 26::Anime|Romance +27::Movie 27::Anime|Anime +28::Movie 28::Thriller|Anime +29::Movie 29::Anime|Romance +30::Movie 30::Thriller|Romance +31::Movie 31::Thriller|Romance +32::Movie 32::Comedy|Anime +33::Movie 33::Comedy|Comedy +34::Movie 34::Anime|Anime +35::Movie 35::Action|Thriller +36::Movie 36::Anime|Romance +37::Movie 37::Romance|Anime +38::Movie 38::Thriller|Romance +39::Movie 39::Romance|Comedy +40::Movie 40::Action|Anime +41::Movie 41::Comedy|Thriller +42::Movie 42::Comedy|Action +43::Movie 43::Thriller|Anime +44::Movie 44::Anime|Action +45::Movie 45::Comedy|Romance +46::Movie 46::Comedy|Action +47::Movie 47::Romance|Comedy +48::Movie 48::Action|Comedy +49::Movie 49::Romance|Romance +50::Movie 50::Comedy|Romance +51::Movie 51::Action|Action +52::Movie 52::Thriller|Action +53::Movie 53::Action|Action +54::Movie 54::Romance|Thriller +55::Movie 55::Anime|Romance +56::Movie 56::Comedy|Action +57::Movie 57::Action|Anime +58::Movie 58::Thriller|Romance +59::Movie 59::Thriller|Comedy +60::Movie 60::Anime|Comedy +61::Movie 61::Comedy|Action +62::Movie 62::Comedy|Romance +63::Movie 63::Romance|Thriller +64::Movie 64::Romance|Action +65::Movie 65::Anime|Romance +66::Movie 66::Comedy|Action +67::Movie 67::Thriller|Anime +68::Movie 68::Thriller|Romance +69::Movie 69::Action|Comedy +70::Movie 70::Thriller|Thriller +71::Movie 71::Action|Comedy +72::Movie 72::Thriller|Romance +73::Movie 73::Comedy|Action +74::Movie 74::Action|Action +75::Movie 75::Action|Action +76::Movie 76::Comedy|Comedy +77::Movie 77::Comedy|Comedy +78::Movie 78::Comedy|Comedy +79::Movie 79::Thriller|Thriller +80::Movie 80::Comedy|Anime +81::Movie 81::Comedy|Anime +82::Movie 82::Romance|Anime +83::Movie 83::Comedy|Thriller +84::Movie 84::Anime|Action +85::Movie 85::Thriller|Anime +86::Movie 86::Romance|Anime +87::Movie 
87::Thriller|Thriller +88::Movie 88::Romance|Thriller +89::Movie 89::Action|Anime +90::Movie 90::Anime|Romance +91::Movie 91::Anime|Thriller +92::Movie 92::Action|Comedy +93::Movie 93::Romance|Thriller +94::Movie 94::Thriller|Comedy +95::Movie 95::Action|Action +96::Movie 96::Thriller|Romance +97::Movie 97::Thriller|Thriller +98::Movie 98::Thriller|Comedy +99::Movie 99::Thriller|Romance diff --git a/data/mllib/als/sample_movielens_ratings.txt b/data/mllib/als/sample_movielens_ratings.txt new file mode 100644 index 0000000000000..0889142950797 --- /dev/null +++ b/data/mllib/als/sample_movielens_ratings.txt @@ -0,0 +1,1501 @@ +0::2::3::1424380312 +0::3::1::1424380312 +0::5::2::1424380312 +0::9::4::1424380312 +0::11::1::1424380312 +0::12::2::1424380312 +0::15::1::1424380312 +0::17::1::1424380312 +0::19::1::1424380312 +0::21::1::1424380312 +0::23::1::1424380312 +0::26::3::1424380312 +0::27::1::1424380312 +0::28::1::1424380312 +0::29::1::1424380312 +0::30::1::1424380312 +0::31::1::1424380312 +0::34::1::1424380312 +0::37::1::1424380312 +0::41::2::1424380312 +0::44::1::1424380312 +0::45::2::1424380312 +0::46::1::1424380312 +0::47::1::1424380312 +0::48::1::1424380312 +0::50::1::1424380312 +0::51::1::1424380312 +0::54::1::1424380312 +0::55::1::1424380312 +0::59::2::1424380312 +0::61::2::1424380312 +0::64::1::1424380312 +0::67::1::1424380312 +0::68::1::1424380312 +0::69::1::1424380312 +0::71::1::1424380312 +0::72::1::1424380312 +0::77::2::1424380312 +0::79::1::1424380312 +0::83::1::1424380312 +0::87::1::1424380312 +0::89::2::1424380312 +0::91::3::1424380312 +0::92::4::1424380312 +0::94::1::1424380312 +0::95::2::1424380312 +0::96::1::1424380312 +0::98::1::1424380312 +0::99::1::1424380312 +1::2::2::1424380312 +1::3::1::1424380312 +1::4::2::1424380312 +1::6::1::1424380312 +1::9::3::1424380312 +1::12::1::1424380312 +1::13::1::1424380312 +1::14::1::1424380312 +1::16::1::1424380312 +1::19::1::1424380312 +1::21::3::1424380312 +1::27::1::1424380312 +1::28::3::1424380312 
+1::33::1::1424380312 +1::36::2::1424380312 +1::37::1::1424380312 +1::40::1::1424380312 +1::41::2::1424380312 +1::43::1::1424380312 +1::44::1::1424380312 +1::47::1::1424380312 +1::50::1::1424380312 +1::54::1::1424380312 +1::56::2::1424380312 +1::57::1::1424380312 +1::58::1::1424380312 +1::60::1::1424380312 +1::62::4::1424380312 +1::63::1::1424380312 +1::67::1::1424380312 +1::68::4::1424380312 +1::70::2::1424380312 +1::72::1::1424380312 +1::73::1::1424380312 +1::74::2::1424380312 +1::76::1::1424380312 +1::77::3::1424380312 +1::78::1::1424380312 +1::81::1::1424380312 +1::82::1::1424380312 +1::85::3::1424380312 +1::86::2::1424380312 +1::88::2::1424380312 +1::91::1::1424380312 +1::92::2::1424380312 +1::93::1::1424380312 +1::94::2::1424380312 +1::96::1::1424380312 +1::97::1::1424380312 +2::4::3::1424380312 +2::6::1::1424380312 +2::8::5::1424380312 +2::9::1::1424380312 +2::10::1::1424380312 +2::12::3::1424380312 +2::13::1::1424380312 +2::15::2::1424380312 +2::18::2::1424380312 +2::19::4::1424380312 +2::22::1::1424380312 +2::26::1::1424380312 +2::28::1::1424380312 +2::34::4::1424380312 +2::35::1::1424380312 +2::37::5::1424380312 +2::38::1::1424380312 +2::39::5::1424380312 +2::40::4::1424380312 +2::47::1::1424380312 +2::50::1::1424380312 +2::52::2::1424380312 +2::54::1::1424380312 +2::55::1::1424380312 +2::57::2::1424380312 +2::58::2::1424380312 +2::59::1::1424380312 +2::61::1::1424380312 +2::62::1::1424380312 +2::64::1::1424380312 +2::65::1::1424380312 +2::66::3::1424380312 +2::68::1::1424380312 +2::71::3::1424380312 +2::76::1::1424380312 +2::77::1::1424380312 +2::78::1::1424380312 +2::80::1::1424380312 +2::83::5::1424380312 +2::85::1::1424380312 +2::87::2::1424380312 +2::88::1::1424380312 +2::89::4::1424380312 +2::90::1::1424380312 +2::92::4::1424380312 +2::93::5::1424380312 +3::0::1::1424380312 +3::1::1::1424380312 +3::2::1::1424380312 +3::7::3::1424380312 +3::8::3::1424380312 +3::9::1::1424380312 +3::14::1::1424380312 +3::15::1::1424380312 +3::16::1::1424380312 
+3::18::4::1424380312 +3::19::1::1424380312 +3::24::3::1424380312 +3::26::1::1424380312 +3::29::3::1424380312 +3::33::1::1424380312 +3::34::3::1424380312 +3::35::1::1424380312 +3::36::3::1424380312 +3::37::1::1424380312 +3::38::2::1424380312 +3::43::1::1424380312 +3::44::1::1424380312 +3::46::1::1424380312 +3::47::1::1424380312 +3::51::5::1424380312 +3::52::3::1424380312 +3::56::1::1424380312 +3::58::1::1424380312 +3::60::3::1424380312 +3::62::1::1424380312 +3::65::2::1424380312 +3::66::1::1424380312 +3::67::1::1424380312 +3::68::2::1424380312 +3::70::1::1424380312 +3::72::2::1424380312 +3::76::3::1424380312 +3::79::3::1424380312 +3::80::4::1424380312 +3::81::1::1424380312 +3::83::1::1424380312 +3::84::1::1424380312 +3::86::1::1424380312 +3::87::2::1424380312 +3::88::4::1424380312 +3::89::1::1424380312 +3::91::1::1424380312 +3::94::3::1424380312 +4::1::1::1424380312 +4::6::1::1424380312 +4::8::1::1424380312 +4::9::1::1424380312 +4::10::1::1424380312 +4::11::1::1424380312 +4::12::1::1424380312 +4::13::1::1424380312 +4::14::2::1424380312 +4::15::1::1424380312 +4::17::1::1424380312 +4::20::1::1424380312 +4::22::1::1424380312 +4::23::1::1424380312 +4::24::1::1424380312 +4::29::4::1424380312 +4::30::1::1424380312 +4::31::1::1424380312 +4::34::1::1424380312 +4::35::1::1424380312 +4::36::1::1424380312 +4::39::2::1424380312 +4::40::3::1424380312 +4::41::4::1424380312 +4::43::2::1424380312 +4::44::1::1424380312 +4::45::1::1424380312 +4::46::1::1424380312 +4::47::1::1424380312 +4::49::2::1424380312 +4::50::1::1424380312 +4::51::1::1424380312 +4::52::4::1424380312 +4::54::1::1424380312 +4::55::1::1424380312 +4::60::3::1424380312 +4::61::1::1424380312 +4::62::4::1424380312 +4::63::3::1424380312 +4::65::1::1424380312 +4::67::2::1424380312 +4::69::1::1424380312 +4::70::4::1424380312 +4::71::1::1424380312 +4::73::1::1424380312 +4::78::1::1424380312 +4::84::1::1424380312 +4::85::1::1424380312 +4::87::3::1424380312 +4::88::3::1424380312 +4::89::2::1424380312 +4::96::1::1424380312 
+4::97::1::1424380312 +4::98::1::1424380312 +4::99::1::1424380312 +5::0::1::1424380312 +5::1::1::1424380312 +5::4::1::1424380312 +5::5::1::1424380312 +5::8::1::1424380312 +5::9::3::1424380312 +5::10::2::1424380312 +5::13::3::1424380312 +5::15::1::1424380312 +5::19::1::1424380312 +5::20::3::1424380312 +5::21::2::1424380312 +5::23::3::1424380312 +5::27::1::1424380312 +5::28::1::1424380312 +5::29::1::1424380312 +5::31::1::1424380312 +5::36::3::1424380312 +5::38::2::1424380312 +5::39::1::1424380312 +5::42::1::1424380312 +5::48::3::1424380312 +5::49::4::1424380312 +5::50::3::1424380312 +5::51::1::1424380312 +5::52::1::1424380312 +5::54::1::1424380312 +5::55::5::1424380312 +5::56::3::1424380312 +5::58::1::1424380312 +5::60::1::1424380312 +5::61::1::1424380312 +5::64::3::1424380312 +5::65::2::1424380312 +5::68::4::1424380312 +5::70::1::1424380312 +5::71::1::1424380312 +5::72::1::1424380312 +5::74::1::1424380312 +5::79::1::1424380312 +5::81::2::1424380312 +5::84::1::1424380312 +5::85::1::1424380312 +5::86::1::1424380312 +5::88::1::1424380312 +5::90::4::1424380312 +5::91::2::1424380312 +5::95::2::1424380312 +5::99::1::1424380312 +6::0::1::1424380312 +6::1::1::1424380312 +6::2::3::1424380312 +6::5::1::1424380312 +6::6::1::1424380312 +6::9::1::1424380312 +6::10::1::1424380312 +6::15::2::1424380312 +6::16::2::1424380312 +6::17::1::1424380312 +6::18::1::1424380312 +6::20::1::1424380312 +6::21::1::1424380312 +6::22::1::1424380312 +6::24::1::1424380312 +6::25::5::1424380312 +6::26::1::1424380312 +6::28::1::1424380312 +6::30::1::1424380312 +6::33::1::1424380312 +6::38::1::1424380312 +6::39::1::1424380312 +6::43::4::1424380312 +6::44::1::1424380312 +6::45::1::1424380312 +6::48::1::1424380312 +6::49::1::1424380312 +6::50::1::1424380312 +6::53::1::1424380312 +6::54::1::1424380312 +6::55::1::1424380312 +6::56::1::1424380312 +6::58::4::1424380312 +6::59::1::1424380312 +6::60::1::1424380312 +6::61::3::1424380312 +6::63::3::1424380312 +6::66::1::1424380312 +6::67::3::1424380312 
+6::68::1::1424380312 +6::69::1::1424380312 +6::71::2::1424380312 +6::73::1::1424380312 +6::75::1::1424380312 +6::77::1::1424380312 +6::79::1::1424380312 +6::81::1::1424380312 +6::84::1::1424380312 +6::85::3::1424380312 +6::86::1::1424380312 +6::87::1::1424380312 +6::88::1::1424380312 +6::89::1::1424380312 +6::91::2::1424380312 +6::94::1::1424380312 +6::95::2::1424380312 +6::96::1::1424380312 +7::1::1::1424380312 +7::2::2::1424380312 +7::3::1::1424380312 +7::4::1::1424380312 +7::7::1::1424380312 +7::10::1::1424380312 +7::11::2::1424380312 +7::14::2::1424380312 +7::15::1::1424380312 +7::16::1::1424380312 +7::18::1::1424380312 +7::21::1::1424380312 +7::22::1::1424380312 +7::23::1::1424380312 +7::25::5::1424380312 +7::26::1::1424380312 +7::29::4::1424380312 +7::30::1::1424380312 +7::31::3::1424380312 +7::32::1::1424380312 +7::33::1::1424380312 +7::35::1::1424380312 +7::37::2::1424380312 +7::39::3::1424380312 +7::40::2::1424380312 +7::42::2::1424380312 +7::44::1::1424380312 +7::45::2::1424380312 +7::47::4::1424380312 +7::48::1::1424380312 +7::49::1::1424380312 +7::53::1::1424380312 +7::54::1::1424380312 +7::55::1::1424380312 +7::56::1::1424380312 +7::59::1::1424380312 +7::61::2::1424380312 +7::62::3::1424380312 +7::63::2::1424380312 +7::66::1::1424380312 +7::67::3::1424380312 +7::74::1::1424380312 +7::75::1::1424380312 +7::76::3::1424380312 +7::77::1::1424380312 +7::81::1::1424380312 +7::82::1::1424380312 +7::84::2::1424380312 +7::85::4::1424380312 +7::86::1::1424380312 +7::92::2::1424380312 +7::96::1::1424380312 +7::97::1::1424380312 +7::98::1::1424380312 +8::0::1::1424380312 +8::2::4::1424380312 +8::3::2::1424380312 +8::4::2::1424380312 +8::5::1::1424380312 +8::7::1::1424380312 +8::9::1::1424380312 +8::11::1::1424380312 +8::15::1::1424380312 +8::18::1::1424380312 +8::19::1::1424380312 +8::21::1::1424380312 +8::29::5::1424380312 +8::31::3::1424380312 +8::33::1::1424380312 +8::35::1::1424380312 +8::36::1::1424380312 +8::40::2::1424380312 +8::44::1::1424380312 
+8::45::1::1424380312 +8::50::1::1424380312 +8::51::1::1424380312 +8::52::5::1424380312 +8::53::5::1424380312 +8::54::1::1424380312 +8::55::1::1424380312 +8::56::1::1424380312 +8::58::4::1424380312 +8::60::3::1424380312 +8::62::4::1424380312 +8::64::1::1424380312 +8::67::3::1424380312 +8::69::1::1424380312 +8::71::1::1424380312 +8::72::3::1424380312 +8::77::3::1424380312 +8::78::1::1424380312 +8::79::1::1424380312 +8::83::1::1424380312 +8::85::5::1424380312 +8::86::1::1424380312 +8::88::1::1424380312 +8::90::1::1424380312 +8::92::2::1424380312 +8::95::4::1424380312 +8::96::3::1424380312 +8::97::1::1424380312 +8::98::1::1424380312 +8::99::1::1424380312 +9::2::3::1424380312 +9::3::1::1424380312 +9::4::1::1424380312 +9::5::1::1424380312 +9::6::1::1424380312 +9::7::5::1424380312 +9::9::1::1424380312 +9::12::1::1424380312 +9::14::3::1424380312 +9::15::1::1424380312 +9::19::1::1424380312 +9::21::1::1424380312 +9::22::1::1424380312 +9::24::1::1424380312 +9::25::1::1424380312 +9::26::1::1424380312 +9::30::3::1424380312 +9::32::4::1424380312 +9::35::2::1424380312 +9::36::2::1424380312 +9::37::2::1424380312 +9::38::1::1424380312 +9::39::1::1424380312 +9::43::3::1424380312 +9::49::5::1424380312 +9::50::3::1424380312 +9::53::1::1424380312 +9::54::1::1424380312 +9::58::1::1424380312 +9::59::1::1424380312 +9::60::1::1424380312 +9::61::1::1424380312 +9::63::3::1424380312 +9::64::3::1424380312 +9::68::1::1424380312 +9::69::1::1424380312 +9::70::3::1424380312 +9::71::1::1424380312 +9::73::2::1424380312 +9::75::1::1424380312 +9::77::2::1424380312 +9::81::2::1424380312 +9::82::1::1424380312 +9::83::1::1424380312 +9::84::1::1424380312 +9::86::1::1424380312 +9::87::4::1424380312 +9::88::1::1424380312 +9::90::3::1424380312 +9::94::2::1424380312 +9::95::3::1424380312 +9::97::2::1424380312 +9::98::1::1424380312 +10::0::3::1424380312 +10::2::4::1424380312 +10::4::3::1424380312 +10::7::1::1424380312 +10::8::1::1424380312 +10::10::1::1424380312 +10::13::2::1424380312 +10::14::1::1424380312 
+10::16::2::1424380312 +10::17::1::1424380312 +10::18::1::1424380312 +10::21::1::1424380312 +10::22::1::1424380312 +10::24::1::1424380312 +10::25::3::1424380312 +10::28::1::1424380312 +10::35::1::1424380312 +10::36::1::1424380312 +10::37::1::1424380312 +10::38::1::1424380312 +10::39::1::1424380312 +10::40::4::1424380312 +10::41::2::1424380312 +10::42::3::1424380312 +10::43::1::1424380312 +10::49::3::1424380312 +10::50::1::1424380312 +10::51::1::1424380312 +10::52::1::1424380312 +10::55::2::1424380312 +10::56::1::1424380312 +10::58::1::1424380312 +10::63::1::1424380312 +10::66::1::1424380312 +10::67::2::1424380312 +10::68::1::1424380312 +10::75::1::1424380312 +10::77::1::1424380312 +10::79::1::1424380312 +10::86::1::1424380312 +10::89::3::1424380312 +10::90::1::1424380312 +10::97::1::1424380312 +10::98::1::1424380312 +11::0::1::1424380312 +11::6::2::1424380312 +11::9::1::1424380312 +11::10::1::1424380312 +11::11::1::1424380312 +11::12::1::1424380312 +11::13::4::1424380312 +11::16::1::1424380312 +11::18::5::1424380312 +11::19::4::1424380312 +11::20::1::1424380312 +11::21::1::1424380312 +11::22::1::1424380312 +11::23::5::1424380312 +11::25::1::1424380312 +11::27::5::1424380312 +11::30::5::1424380312 +11::32::5::1424380312 +11::35::3::1424380312 +11::36::2::1424380312 +11::37::2::1424380312 +11::38::4::1424380312 +11::39::1::1424380312 +11::40::1::1424380312 +11::41::1::1424380312 +11::43::2::1424380312 +11::45::1::1424380312 +11::47::1::1424380312 +11::48::5::1424380312 +11::50::4::1424380312 +11::51::3::1424380312 +11::59::1::1424380312 +11::61::1::1424380312 +11::62::1::1424380312 +11::64::1::1424380312 +11::66::4::1424380312 +11::67::1::1424380312 +11::69::5::1424380312 +11::70::1::1424380312 +11::71::3::1424380312 +11::72::3::1424380312 +11::75::3::1424380312 +11::76::1::1424380312 +11::77::1::1424380312 +11::78::1::1424380312 +11::79::5::1424380312 +11::80::3::1424380312 +11::81::4::1424380312 +11::82::1::1424380312 +11::86::1::1424380312 +11::88::1::1424380312 
+11::89::1::1424380312 +11::90::4::1424380312 +11::94::2::1424380312 +11::97::3::1424380312 +11::99::1::1424380312 +12::2::1::1424380312 +12::4::1::1424380312 +12::6::1::1424380312 +12::7::3::1424380312 +12::8::1::1424380312 +12::14::1::1424380312 +12::15::2::1424380312 +12::16::4::1424380312 +12::17::5::1424380312 +12::18::2::1424380312 +12::21::1::1424380312 +12::22::2::1424380312 +12::23::3::1424380312 +12::24::1::1424380312 +12::25::1::1424380312 +12::27::5::1424380312 +12::30::2::1424380312 +12::31::4::1424380312 +12::35::5::1424380312 +12::38::1::1424380312 +12::41::1::1424380312 +12::44::2::1424380312 +12::45::1::1424380312 +12::50::4::1424380312 +12::51::1::1424380312 +12::52::1::1424380312 +12::53::1::1424380312 +12::54::1::1424380312 +12::56::2::1424380312 +12::57::1::1424380312 +12::60::1::1424380312 +12::63::1::1424380312 +12::64::5::1424380312 +12::66::3::1424380312 +12::67::1::1424380312 +12::70::1::1424380312 +12::72::1::1424380312 +12::74::1::1424380312 +12::75::1::1424380312 +12::77::1::1424380312 +12::78::1::1424380312 +12::79::3::1424380312 +12::82::2::1424380312 +12::83::1::1424380312 +12::84::1::1424380312 +12::85::1::1424380312 +12::86::1::1424380312 +12::87::1::1424380312 +12::88::1::1424380312 +12::91::3::1424380312 +12::92::1::1424380312 +12::94::4::1424380312 +12::95::2::1424380312 +12::96::1::1424380312 +12::98::2::1424380312 +13::0::1::1424380312 +13::3::1::1424380312 +13::4::2::1424380312 +13::5::1::1424380312 +13::6::1::1424380312 +13::12::1::1424380312 +13::14::2::1424380312 +13::15::1::1424380312 +13::17::1::1424380312 +13::18::3::1424380312 +13::20::1::1424380312 +13::21::1::1424380312 +13::22::1::1424380312 +13::26::1::1424380312 +13::27::1::1424380312 +13::29::3::1424380312 +13::31::1::1424380312 +13::33::1::1424380312 +13::40::2::1424380312 +13::43::2::1424380312 +13::44::1::1424380312 +13::45::1::1424380312 +13::49::1::1424380312 +13::51::1::1424380312 +13::52::2::1424380312 +13::53::3::1424380312 +13::54::1::1424380312 
+13::62::1::1424380312 +13::63::2::1424380312 +13::64::1::1424380312 +13::68::1::1424380312 +13::71::1::1424380312 +13::72::3::1424380312 +13::73::1::1424380312 +13::74::3::1424380312 +13::77::2::1424380312 +13::78::1::1424380312 +13::79::2::1424380312 +13::83::3::1424380312 +13::85::1::1424380312 +13::86::1::1424380312 +13::87::2::1424380312 +13::88::2::1424380312 +13::90::1::1424380312 +13::93::4::1424380312 +13::94::1::1424380312 +13::98::1::1424380312 +13::99::1::1424380312 +14::1::1::1424380312 +14::3::3::1424380312 +14::4::1::1424380312 +14::5::1::1424380312 +14::6::1::1424380312 +14::7::1::1424380312 +14::9::1::1424380312 +14::10::1::1424380312 +14::11::1::1424380312 +14::12::1::1424380312 +14::13::1::1424380312 +14::14::3::1424380312 +14::15::1::1424380312 +14::16::1::1424380312 +14::17::1::1424380312 +14::20::1::1424380312 +14::21::1::1424380312 +14::24::1::1424380312 +14::25::2::1424380312 +14::27::1::1424380312 +14::28::1::1424380312 +14::29::5::1424380312 +14::31::3::1424380312 +14::34::1::1424380312 +14::36::1::1424380312 +14::37::2::1424380312 +14::39::2::1424380312 +14::40::1::1424380312 +14::44::1::1424380312 +14::45::1::1424380312 +14::47::3::1424380312 +14::48::1::1424380312 +14::49::1::1424380312 +14::51::1::1424380312 +14::52::5::1424380312 +14::53::3::1424380312 +14::54::1::1424380312 +14::55::1::1424380312 +14::56::1::1424380312 +14::62::4::1424380312 +14::63::5::1424380312 +14::67::3::1424380312 +14::68::1::1424380312 +14::69::3::1424380312 +14::71::1::1424380312 +14::72::4::1424380312 +14::73::1::1424380312 +14::76::5::1424380312 +14::79::1::1424380312 +14::82::1::1424380312 +14::83::1::1424380312 +14::88::1::1424380312 +14::93::3::1424380312 +14::94::1::1424380312 +14::95::2::1424380312 +14::96::4::1424380312 +14::98::1::1424380312 +15::0::1::1424380312 +15::1::4::1424380312 +15::2::1::1424380312 +15::5::2::1424380312 +15::6::1::1424380312 +15::7::1::1424380312 +15::13::1::1424380312 +15::14::1::1424380312 +15::15::1::1424380312 
+15::17::2::1424380312 +15::19::2::1424380312 +15::22::2::1424380312 +15::23::2::1424380312 +15::25::1::1424380312 +15::26::3::1424380312 +15::27::1::1424380312 +15::28::2::1424380312 +15::29::1::1424380312 +15::32::1::1424380312 +15::33::2::1424380312 +15::34::1::1424380312 +15::35::2::1424380312 +15::36::1::1424380312 +15::37::1::1424380312 +15::39::1::1424380312 +15::42::1::1424380312 +15::46::5::1424380312 +15::48::2::1424380312 +15::50::2::1424380312 +15::51::1::1424380312 +15::52::1::1424380312 +15::58::1::1424380312 +15::62::1::1424380312 +15::64::3::1424380312 +15::65::2::1424380312 +15::72::1::1424380312 +15::73::1::1424380312 +15::74::1::1424380312 +15::79::1::1424380312 +15::80::1::1424380312 +15::81::1::1424380312 +15::82::2::1424380312 +15::85::1::1424380312 +15::87::1::1424380312 +15::91::2::1424380312 +15::96::1::1424380312 +15::97::1::1424380312 +15::98::3::1424380312 +16::2::1::1424380312 +16::5::3::1424380312 +16::6::2::1424380312 +16::7::1::1424380312 +16::9::1::1424380312 +16::12::1::1424380312 +16::14::1::1424380312 +16::15::1::1424380312 +16::19::1::1424380312 +16::21::2::1424380312 +16::29::4::1424380312 +16::30::2::1424380312 +16::32::1::1424380312 +16::34::1::1424380312 +16::36::1::1424380312 +16::38::1::1424380312 +16::46::1::1424380312 +16::47::3::1424380312 +16::48::1::1424380312 +16::49::1::1424380312 +16::50::1::1424380312 +16::51::5::1424380312 +16::54::5::1424380312 +16::55::1::1424380312 +16::56::2::1424380312 +16::57::1::1424380312 +16::60::1::1424380312 +16::63::2::1424380312 +16::65::1::1424380312 +16::67::1::1424380312 +16::72::1::1424380312 +16::74::1::1424380312 +16::80::1::1424380312 +16::81::1::1424380312 +16::82::1::1424380312 +16::85::5::1424380312 +16::86::1::1424380312 +16::90::5::1424380312 +16::91::1::1424380312 +16::93::1::1424380312 +16::94::3::1424380312 +16::95::2::1424380312 +16::96::3::1424380312 +16::98::3::1424380312 +16::99::1::1424380312 +17::2::1::1424380312 +17::3::1::1424380312 +17::6::1::1424380312 
+17::10::4::1424380312 +17::11::1::1424380312 +17::13::2::1424380312 +17::17::5::1424380312 +17::19::1::1424380312 +17::20::5::1424380312 +17::22::4::1424380312 +17::28::1::1424380312 +17::29::1::1424380312 +17::33::1::1424380312 +17::34::1::1424380312 +17::35::2::1424380312 +17::37::1::1424380312 +17::38::1::1424380312 +17::45::1::1424380312 +17::46::5::1424380312 +17::47::1::1424380312 +17::49::3::1424380312 +17::51::1::1424380312 +17::55::5::1424380312 +17::56::3::1424380312 +17::57::1::1424380312 +17::58::1::1424380312 +17::59::1::1424380312 +17::60::1::1424380312 +17::63::1::1424380312 +17::66::1::1424380312 +17::68::4::1424380312 +17::69::1::1424380312 +17::70::1::1424380312 +17::72::1::1424380312 +17::73::3::1424380312 +17::78::1::1424380312 +17::79::1::1424380312 +17::82::2::1424380312 +17::84::1::1424380312 +17::90::5::1424380312 +17::91::3::1424380312 +17::92::1::1424380312 +17::93::1::1424380312 +17::94::4::1424380312 +17::95::2::1424380312 +17::97::1::1424380312 +18::1::1::1424380312 +18::4::3::1424380312 +18::5::2::1424380312 +18::6::1::1424380312 +18::7::1::1424380312 +18::10::1::1424380312 +18::11::4::1424380312 +18::12::2::1424380312 +18::13::1::1424380312 +18::15::1::1424380312 +18::18::1::1424380312 +18::20::1::1424380312 +18::21::2::1424380312 +18::22::1::1424380312 +18::23::2::1424380312 +18::25::1::1424380312 +18::26::1::1424380312 +18::27::1::1424380312 +18::28::5::1424380312 +18::29::1::1424380312 +18::31::1::1424380312 +18::32::1::1424380312 +18::36::1::1424380312 +18::38::5::1424380312 +18::39::5::1424380312 +18::40::1::1424380312 +18::42::1::1424380312 +18::43::1::1424380312 +18::44::4::1424380312 +18::46::1::1424380312 +18::47::1::1424380312 +18::48::1::1424380312 +18::51::2::1424380312 +18::55::1::1424380312 +18::56::1::1424380312 +18::57::1::1424380312 +18::62::1::1424380312 +18::63::1::1424380312 +18::66::3::1424380312 +18::67::1::1424380312 +18::70::1::1424380312 +18::75::1::1424380312 +18::76::3::1424380312 +18::77::1::1424380312 
+18::80::3::1424380312 +18::81::3::1424380312 +18::82::1::1424380312 +18::83::5::1424380312 +18::84::1::1424380312 +18::97::1::1424380312 +18::98::1::1424380312 +18::99::2::1424380312 +19::0::1::1424380312 +19::1::1::1424380312 +19::2::1::1424380312 +19::4::1::1424380312 +19::6::2::1424380312 +19::11::1::1424380312 +19::12::1::1424380312 +19::14::1::1424380312 +19::23::1::1424380312 +19::26::1::1424380312 +19::31::1::1424380312 +19::32::4::1424380312 +19::33::1::1424380312 +19::34::1::1424380312 +19::37::1::1424380312 +19::38::1::1424380312 +19::41::1::1424380312 +19::43::1::1424380312 +19::45::1::1424380312 +19::48::1::1424380312 +19::49::1::1424380312 +19::50::2::1424380312 +19::53::2::1424380312 +19::54::3::1424380312 +19::55::1::1424380312 +19::56::2::1424380312 +19::58::1::1424380312 +19::61::1::1424380312 +19::62::1::1424380312 +19::63::1::1424380312 +19::64::1::1424380312 +19::65::1::1424380312 +19::69::2::1424380312 +19::72::1::1424380312 +19::74::3::1424380312 +19::76::1::1424380312 +19::78::1::1424380312 +19::79::1::1424380312 +19::81::1::1424380312 +19::82::1::1424380312 +19::84::1::1424380312 +19::86::1::1424380312 +19::87::2::1424380312 +19::90::4::1424380312 +19::93::1::1424380312 +19::94::4::1424380312 +19::95::2::1424380312 +19::96::1::1424380312 +19::98::4::1424380312 +20::0::1::1424380312 +20::1::1::1424380312 +20::2::2::1424380312 +20::4::2::1424380312 +20::6::1::1424380312 +20::8::1::1424380312 +20::12::1::1424380312 +20::21::2::1424380312 +20::22::5::1424380312 +20::24::2::1424380312 +20::25::1::1424380312 +20::26::1::1424380312 +20::29::2::1424380312 +20::30::2::1424380312 +20::32::2::1424380312 +20::39::1::1424380312 +20::40::1::1424380312 +20::41::2::1424380312 +20::45::2::1424380312 +20::48::1::1424380312 +20::50::1::1424380312 +20::51::3::1424380312 +20::53::3::1424380312 +20::55::1::1424380312 +20::57::2::1424380312 +20::60::1::1424380312 +20::61::1::1424380312 +20::64::1::1424380312 +20::66::1::1424380312 +20::70::2::1424380312 
+20::72::1::1424380312 +20::73::2::1424380312 +20::75::4::1424380312 +20::76::1::1424380312 +20::77::4::1424380312 +20::78::1::1424380312 +20::79::1::1424380312 +20::84::2::1424380312 +20::85::2::1424380312 +20::88::3::1424380312 +20::89::1::1424380312 +20::90::3::1424380312 +20::91::1::1424380312 +20::92::2::1424380312 +20::93::1::1424380312 +20::94::4::1424380312 +20::97::1::1424380312 +21::0::1::1424380312 +21::2::4::1424380312 +21::3::1::1424380312 +21::7::2::1424380312 +21::11::1::1424380312 +21::12::1::1424380312 +21::13::1::1424380312 +21::14::3::1424380312 +21::17::1::1424380312 +21::19::1::1424380312 +21::20::1::1424380312 +21::21::1::1424380312 +21::22::1::1424380312 +21::23::1::1424380312 +21::24::1::1424380312 +21::27::1::1424380312 +21::29::5::1424380312 +21::30::2::1424380312 +21::38::1::1424380312 +21::40::2::1424380312 +21::43::3::1424380312 +21::44::1::1424380312 +21::45::1::1424380312 +21::46::1::1424380312 +21::48::1::1424380312 +21::51::1::1424380312 +21::53::5::1424380312 +21::54::1::1424380312 +21::55::1::1424380312 +21::56::1::1424380312 +21::58::3::1424380312 +21::59::3::1424380312 +21::64::1::1424380312 +21::66::1::1424380312 +21::68::1::1424380312 +21::71::1::1424380312 +21::73::1::1424380312 +21::74::4::1424380312 +21::80::1::1424380312 +21::81::1::1424380312 +21::83::1::1424380312 +21::84::1::1424380312 +21::85::3::1424380312 +21::87::4::1424380312 +21::89::2::1424380312 +21::92::2::1424380312 +21::96::3::1424380312 +21::99::1::1424380312 +22::0::1::1424380312 +22::3::2::1424380312 +22::5::2::1424380312 +22::6::2::1424380312 +22::9::1::1424380312 +22::10::1::1424380312 +22::11::1::1424380312 +22::13::1::1424380312 +22::14::1::1424380312 +22::16::1::1424380312 +22::18::3::1424380312 +22::19::1::1424380312 +22::22::5::1424380312 +22::25::1::1424380312 +22::26::1::1424380312 +22::29::3::1424380312 +22::30::5::1424380312 +22::32::4::1424380312 +22::33::1::1424380312 +22::35::1::1424380312 +22::36::3::1424380312 +22::37::1::1424380312 
+22::40::1::1424380312 +22::41::3::1424380312 +22::44::1::1424380312 +22::45::2::1424380312 +22::48::1::1424380312 +22::51::5::1424380312 +22::55::1::1424380312 +22::56::2::1424380312 +22::60::3::1424380312 +22::61::1::1424380312 +22::62::4::1424380312 +22::63::1::1424380312 +22::65::1::1424380312 +22::66::1::1424380312 +22::68::4::1424380312 +22::69::4::1424380312 +22::70::3::1424380312 +22::71::1::1424380312 +22::74::5::1424380312 +22::75::5::1424380312 +22::78::1::1424380312 +22::80::3::1424380312 +22::81::1::1424380312 +22::82::1::1424380312 +22::84::1::1424380312 +22::86::1::1424380312 +22::87::3::1424380312 +22::88::5::1424380312 +22::90::2::1424380312 +22::92::3::1424380312 +22::95::2::1424380312 +22::96::2::1424380312 +22::98::4::1424380312 +22::99::1::1424380312 +23::0::1::1424380312 +23::2::1::1424380312 +23::4::1::1424380312 +23::6::2::1424380312 +23::10::4::1424380312 +23::12::1::1424380312 +23::13::4::1424380312 +23::14::1::1424380312 +23::15::1::1424380312 +23::18::4::1424380312 +23::22::2::1424380312 +23::23::4::1424380312 +23::24::1::1424380312 +23::25::1::1424380312 +23::26::1::1424380312 +23::27::5::1424380312 +23::28::1::1424380312 +23::29::1::1424380312 +23::30::4::1424380312 +23::32::5::1424380312 +23::33::2::1424380312 +23::36::3::1424380312 +23::37::1::1424380312 +23::38::1::1424380312 +23::39::1::1424380312 +23::43::1::1424380312 +23::48::5::1424380312 +23::49::5::1424380312 +23::50::4::1424380312 +23::53::1::1424380312 +23::55::5::1424380312 +23::57::1::1424380312 +23::59::1::1424380312 +23::60::1::1424380312 +23::61::1::1424380312 +23::64::4::1424380312 +23::65::5::1424380312 +23::66::2::1424380312 +23::67::1::1424380312 +23::68::3::1424380312 +23::69::1::1424380312 +23::72::1::1424380312 +23::73::3::1424380312 +23::77::1::1424380312 +23::82::2::1424380312 +23::83::1::1424380312 +23::84::1::1424380312 +23::85::1::1424380312 +23::87::3::1424380312 +23::88::1::1424380312 +23::95::2::1424380312 +23::97::1::1424380312 +24::4::1::1424380312 
+24::6::3::1424380312 +24::7::1::1424380312 +24::10::2::1424380312 +24::12::1::1424380312 +24::15::1::1424380312 +24::19::1::1424380312 +24::24::1::1424380312 +24::27::3::1424380312 +24::30::5::1424380312 +24::31::1::1424380312 +24::32::3::1424380312 +24::33::1::1424380312 +24::37::1::1424380312 +24::39::1::1424380312 +24::40::1::1424380312 +24::42::1::1424380312 +24::43::3::1424380312 +24::45::2::1424380312 +24::46::1::1424380312 +24::47::1::1424380312 +24::48::1::1424380312 +24::49::1::1424380312 +24::50::1::1424380312 +24::52::5::1424380312 +24::57::1::1424380312 +24::59::4::1424380312 +24::63::4::1424380312 +24::65::1::1424380312 +24::66::1::1424380312 +24::67::1::1424380312 +24::68::3::1424380312 +24::69::5::1424380312 +24::71::1::1424380312 +24::72::4::1424380312 +24::77::4::1424380312 +24::78::1::1424380312 +24::80::1::1424380312 +24::82::1::1424380312 +24::84::1::1424380312 +24::86::1::1424380312 +24::87::1::1424380312 +24::88::2::1424380312 +24::89::1::1424380312 +24::90::5::1424380312 +24::91::1::1424380312 +24::92::1::1424380312 +24::94::2::1424380312 +24::95::1::1424380312 +24::96::5::1424380312 +24::98::1::1424380312 +24::99::1::1424380312 +25::1::3::1424380312 +25::2::1::1424380312 +25::7::1::1424380312 +25::9::1::1424380312 +25::12::3::1424380312 +25::16::3::1424380312 +25::17::1::1424380312 +25::18::1::1424380312 +25::20::1::1424380312 +25::22::1::1424380312 +25::23::1::1424380312 +25::26::2::1424380312 +25::29::1::1424380312 +25::30::1::1424380312 +25::31::2::1424380312 +25::33::4::1424380312 +25::34::3::1424380312 +25::35::2::1424380312 +25::36::1::1424380312 +25::37::1::1424380312 +25::40::1::1424380312 +25::41::1::1424380312 +25::43::1::1424380312 +25::47::4::1424380312 +25::50::1::1424380312 +25::51::1::1424380312 +25::53::1::1424380312 +25::56::1::1424380312 +25::58::2::1424380312 +25::64::2::1424380312 +25::67::2::1424380312 +25::68::1::1424380312 +25::70::1::1424380312 +25::71::4::1424380312 +25::73::1::1424380312 +25::74::1::1424380312 
+25::76::1::1424380312 +25::79::1::1424380312 +25::82::1::1424380312 +25::84::2::1424380312 +25::85::1::1424380312 +25::91::3::1424380312 +25::92::1::1424380312 +25::94::1::1424380312 +25::95::1::1424380312 +25::97::2::1424380312 +26::0::1::1424380312 +26::1::1::1424380312 +26::2::1::1424380312 +26::3::1::1424380312 +26::4::4::1424380312 +26::5::2::1424380312 +26::6::3::1424380312 +26::7::5::1424380312 +26::13::3::1424380312 +26::14::1::1424380312 +26::16::1::1424380312 +26::18::3::1424380312 +26::20::1::1424380312 +26::21::3::1424380312 +26::22::5::1424380312 +26::23::5::1424380312 +26::24::5::1424380312 +26::27::1::1424380312 +26::31::1::1424380312 +26::35::1::1424380312 +26::36::4::1424380312 +26::40::1::1424380312 +26::44::1::1424380312 +26::45::2::1424380312 +26::47::1::1424380312 +26::48::1::1424380312 +26::49::3::1424380312 +26::50::2::1424380312 +26::52::1::1424380312 +26::54::4::1424380312 +26::55::1::1424380312 +26::57::3::1424380312 +26::58::1::1424380312 +26::61::1::1424380312 +26::62::2::1424380312 +26::66::1::1424380312 +26::68::4::1424380312 +26::71::1::1424380312 +26::73::4::1424380312 +26::76::1::1424380312 +26::81::3::1424380312 +26::85::1::1424380312 +26::86::3::1424380312 +26::88::5::1424380312 +26::91::1::1424380312 +26::94::5::1424380312 +26::95::1::1424380312 +26::96::1::1424380312 +26::97::1::1424380312 +27::0::1::1424380312 +27::9::1::1424380312 +27::10::1::1424380312 +27::18::4::1424380312 +27::19::3::1424380312 +27::20::1::1424380312 +27::22::2::1424380312 +27::24::2::1424380312 +27::25::1::1424380312 +27::27::3::1424380312 +27::28::1::1424380312 +27::29::1::1424380312 +27::31::1::1424380312 +27::33::3::1424380312 +27::40::1::1424380312 +27::42::1::1424380312 +27::43::1::1424380312 +27::44::3::1424380312 +27::45::1::1424380312 +27::51::3::1424380312 +27::52::1::1424380312 +27::55::3::1424380312 +27::57::1::1424380312 +27::59::1::1424380312 +27::60::1::1424380312 +27::61::1::1424380312 +27::64::1::1424380312 +27::66::3::1424380312 
+27::68::1::1424380312 +27::70::1::1424380312 +27::71::2::1424380312 +27::72::1::1424380312 +27::75::3::1424380312 +27::78::1::1424380312 +27::80::3::1424380312 +27::82::1::1424380312 +27::83::3::1424380312 +27::86::1::1424380312 +27::87::2::1424380312 +27::90::1::1424380312 +27::91::1::1424380312 +27::92::1::1424380312 +27::93::1::1424380312 +27::94::2::1424380312 +27::95::1::1424380312 +27::98::1::1424380312 +28::0::3::1424380312 +28::1::1::1424380312 +28::2::4::1424380312 +28::3::1::1424380312 +28::6::1::1424380312 +28::7::1::1424380312 +28::12::5::1424380312 +28::13::2::1424380312 +28::14::1::1424380312 +28::15::1::1424380312 +28::17::1::1424380312 +28::19::3::1424380312 +28::20::1::1424380312 +28::23::3::1424380312 +28::24::3::1424380312 +28::27::1::1424380312 +28::29::1::1424380312 +28::33::1::1424380312 +28::34::1::1424380312 +28::36::1::1424380312 +28::38::2::1424380312 +28::39::2::1424380312 +28::44::1::1424380312 +28::45::1::1424380312 +28::49::4::1424380312 +28::50::1::1424380312 +28::52::1::1424380312 +28::54::1::1424380312 +28::56::1::1424380312 +28::57::3::1424380312 +28::58::1::1424380312 +28::59::1::1424380312 +28::60::1::1424380312 +28::62::3::1424380312 +28::63::1::1424380312 +28::65::1::1424380312 +28::75::1::1424380312 +28::78::1::1424380312 +28::81::5::1424380312 +28::82::4::1424380312 +28::83::1::1424380312 +28::85::1::1424380312 +28::88::2::1424380312 +28::89::4::1424380312 +28::90::1::1424380312 +28::92::5::1424380312 +28::94::1::1424380312 +28::95::2::1424380312 +28::98::1::1424380312 +28::99::1::1424380312 +29::3::1::1424380312 +29::4::1::1424380312 +29::5::1::1424380312 +29::7::2::1424380312 +29::9::1::1424380312 +29::10::3::1424380312 +29::11::1::1424380312 +29::13::3::1424380312 +29::14::1::1424380312 +29::15::1::1424380312 +29::17::3::1424380312 +29::19::3::1424380312 +29::22::3::1424380312 +29::23::4::1424380312 +29::25::1::1424380312 +29::29::1::1424380312 +29::31::1::1424380312 +29::32::4::1424380312 +29::33::2::1424380312 
+29::36::2::1424380312 +29::38::3::1424380312 +29::39::1::1424380312 +29::42::1::1424380312 +29::46::5::1424380312 +29::49::3::1424380312 +29::51::2::1424380312 +29::59::1::1424380312 +29::61::1::1424380312 +29::62::1::1424380312 +29::67::1::1424380312 +29::68::3::1424380312 +29::69::1::1424380312 +29::70::1::1424380312 +29::74::1::1424380312 +29::75::1::1424380312 +29::79::2::1424380312 +29::80::1::1424380312 +29::81::2::1424380312 +29::83::1::1424380312 +29::85::1::1424380312 +29::86::1::1424380312 +29::90::4::1424380312 +29::93::1::1424380312 +29::94::4::1424380312 +29::97::1::1424380312 +29::99::1::1424380312 diff --git a/data/mllib/gmm_data.txt b/data/mllib/gmm_data.txt new file mode 100644 index 0000000000000..934ee4a83a2df --- /dev/null +++ b/data/mllib/gmm_data.txt @@ -0,0 +1,2000 @@ + 2.59470454e+00 2.12298217e+00 + 1.15807024e+00 -1.46498723e-01 + 2.46206638e+00 6.19556894e-01 + -5.54845070e-01 -7.24700066e-01 + -3.23111426e+00 -1.42579084e+00 + 3.02978115e+00 7.87121753e-01 + 1.97365907e+00 1.15914704e+00 + -6.44852101e+00 -3.18154314e+00 + 1.30963349e+00 1.62866434e-01 + 4.26482541e+00 2.15547996e+00 + 3.79927257e+00 1.50572445e+00 + 4.17452609e-01 -6.74032760e-01 + 4.21117627e-01 4.45590255e-01 + -2.80425571e+00 -7.77150554e-01 + 2.55928797e+00 7.03954218e-01 + 1.32554059e+00 -9.46663152e-01 + -3.39691439e+00 -1.49005743e+00 + -2.26542270e-01 3.60052515e-02 + 1.04994198e+00 5.29825685e-01 + -1.51566882e+00 -1.86264432e-01 + -3.27928172e-01 -7.60859110e-01 + -3.18054866e-01 3.97719805e-01 + 1.65579418e-01 -3.47232033e-01 + 6.47162333e-01 4.96059961e-02 + -2.80776647e-01 4.79418757e-01 + 7.45069752e-01 1.20790281e-01 + 2.13604102e-01 1.59542555e-01 + -3.08860224e+00 -1.43259870e+00 + 8.97066497e-01 1.10206801e+00 + -2.23918874e-01 -1.07267267e+00 + 2.51525708e+00 2.84761973e-01 + 9.98052532e-01 1.08333783e+00 + 1.76705588e+00 8.18866778e-01 + 5.31555163e-02 -1.90111151e-01 + -2.17405059e+00 7.21854582e-02 + -2.13772505e+00 -3.62010387e-01 + 
2.95974057e+00 1.31602381e+00 + 2.74053561e+00 1.61781757e+00 + 6.68135448e-01 2.86586009e-01 + 2.82323739e+00 1.74437257e+00 + 8.11540288e-01 5.50744478e-01 + 4.10050897e-01 5.10668402e-03 + 9.58626136e-01 -3.49633680e-01 + 4.66599798e+00 1.49964894e+00 + 4.94507794e-01 2.58928077e-01 + -2.36029742e+00 -1.61042909e+00 + -4.99306804e-01 -8.04984769e-01 + 1.07448510e+00 9.39605828e-01 + -1.80448949e+00 -1.05983264e+00 + -3.22353821e-01 1.73612093e-01 + 1.85418702e+00 1.15640643e+00 + 6.93794163e-01 6.59993560e-01 + 1.99399102e+00 1.44547123e+00 + 3.38866124e+00 1.23379290e+00 + -4.24067720e+00 -1.22264282e+00 + 6.03230201e-02 2.95232729e-01 + -3.59341813e+00 -7.17453726e-01 + 4.87447372e-01 -2.00733911e-01 + 1.20149195e+00 4.07880197e-01 + -2.13331464e+00 -4.58518077e-01 + -3.84091083e+00 -1.71553950e+00 + -5.37279250e-01 2.64822629e-02 + -2.10155227e+00 -1.32558103e+00 + -1.71318897e+00 -7.12098563e-01 + -1.46280695e+00 -1.84868337e-01 + -3.59785325e+00 -1.54832434e+00 + -5.77528081e-01 -5.78580857e-01 + 3.14734283e-01 5.80184639e-01 + -2.71164714e+00 -1.19379432e+00 + 1.09634489e+00 7.20143887e-01 + -3.05527722e+00 -1.47774064e+00 + 6.71753586e-01 7.61350020e-01 + 3.98294144e+00 1.54166484e+00 + -3.37220384e+00 -2.21332064e+00 + 1.81222914e+00 7.41212752e-01 + 2.71458282e-01 1.36329078e-01 + -3.97815359e-01 1.16766886e-01 + -1.70192814e+00 -9.75851571e-01 + -3.46803804e+00 -1.09965988e+00 + -1.69649627e+00 -5.76045801e-01 + -1.02485636e-01 -8.81841246e-01 + -3.24194667e-02 2.55429276e-01 + -2.75343168e+00 -1.51366320e+00 + -2.78676702e+00 -5.22360489e-01 + 1.70483164e+00 1.19769805e+00 + 4.92022579e-01 3.24944706e-01 + 2.48768464e+00 1.00055363e+00 + 4.48786400e-01 7.63902870e-01 + 2.93862696e+00 1.73809968e+00 + -3.55019305e+00 -1.97875558e+00 + 1.74270784e+00 6.90229224e-01 + 5.13391994e-01 4.58374016e-01 + 1.78379499e+00 9.08026381e-01 + 1.75814147e+00 7.41449784e-01 + -2.30687792e-01 3.91009729e-01 + 3.92271353e+00 1.44006290e+00 + 2.93361679e-01 
-4.99886375e-03 + 2.47902690e-01 -7.49542503e-01 + -3.97675355e-01 1.36824887e-01 + 3.56535953e+00 1.15181329e+00 + 3.22425301e+00 1.28702383e+00 + -2.94192478e-01 -2.42382557e-01 + 8.02068864e-01 -1.51671475e-01 + 8.54133530e-01 -4.89514885e-02 + -1.64316316e-01 -5.34642346e-01 + -6.08485405e-01 -2.10332352e-01 + -2.18940059e+00 -1.07024952e+00 + -1.71586960e+00 -2.83333492e-02 + 1.70200448e-01 -3.28031178e-01 + -1.97210346e+00 -5.39948532e-01 + 2.19500160e+00 1.05697170e+00 + -1.76239935e+00 -1.09377438e+00 + 1.68314744e+00 6.86491164e-01 + -2.99852288e+00 -1.46619067e+00 + -2.23769560e+00 -9.15008355e-01 + 9.46887516e-01 5.58410503e-01 + 5.02153123e-01 1.63851235e-01 + -9.70297062e-01 3.14625374e-01 + -1.29405593e+00 -8.20994131e-01 + 2.72516079e+00 7.85839947e-01 + 1.45788024e+00 3.37487353e-01 + -4.36292749e-01 -5.42150480e-01 + 2.21304711e+00 1.25254042e+00 + -1.20810271e-01 4.79632898e-01 + -3.30884511e+00 -1.50607586e+00 + -6.55882455e+00 -1.94231256e+00 + -3.17033630e+00 -9.94678930e-01 + 1.42043617e+00 7.28808957e-01 + -1.57546099e+00 -1.10320497e+00 + -3.22748754e+00 -1.64174579e+00 + 2.96776017e-03 -3.16191512e-02 + -2.25986054e+00 -6.13123197e-01 + 2.49434243e+00 7.73069183e-01 + 9.08494049e-01 -1.53926853e-01 + -2.80559090e+00 -1.37474221e+00 + 4.75224286e-01 2.53153674e-01 + 4.37644006e+00 8.49116998e-01 + 2.27282959e+00 6.16568202e-01 + 1.16006880e+00 1.65832798e-01 + -1.67163193e+00 -1.22555386e+00 + -1.38231118e+00 -7.29575504e-01 + -3.49922750e+00 -2.26446675e+00 + -3.73780110e-01 -1.90657869e-01 + 1.68627679e+00 1.05662987e+00 + -3.28891792e+00 -1.11080334e+00 + -2.59815798e+00 -1.51410198e+00 + -2.61203309e+00 -6.00143552e-01 + 6.58964943e-01 4.47216094e-01 + -2.26711381e+00 -7.26512923e-01 + -5.31429009e-02 -1.97925341e-02 + 3.19749807e+00 9.20425476e-01 + -1.37595787e+00 -6.58062732e-01 + 8.09900278e-01 -3.84286160e-01 + -5.07741280e+00 -1.97683808e+00 + -2.99764250e+00 -1.50753777e+00 + -9.87671815e-01 -4.63255889e-01 + 1.65390765e+00 
6.73806615e-02 + 5.51252659e+00 2.69842267e+00 + -2.23724309e+00 -4.77624004e-01 + 4.99726228e+00 1.74690949e+00 + 1.75859162e-01 -1.49350995e-01 + 4.13382789e+00 1.31735161e+00 + 2.69058117e+00 4.87656923e-01 + 1.07180318e+00 1.01426954e+00 + 3.37216869e+00 1.05955377e+00 + -2.95006781e+00 -1.57048303e+00 + -2.46401648e+00 -8.37056374e-01 + 1.19012962e-01 7.54702770e-01 + 3.34142539e+00 4.81938295e-01 + 2.92643913e+00 1.04301050e+00 + 2.89697751e+00 1.37551442e+00 + -1.03094242e+00 2.20903962e-01 + -5.13914589e+00 -2.23355387e+00 + -8.81680780e-01 1.83590000e-01 + 2.82334775e+00 1.26650464e+00 + -2.81042540e-01 -3.26370240e-01 + 2.97995487e+00 8.34569452e-01 + -1.39857135e+00 -1.15798385e+00 + 4.27186506e+00 9.04253702e-01 + 6.98684517e-01 7.91167305e-01 + 3.52233095e+00 1.29976473e+00 + 2.21448029e+00 2.73213379e-01 + -3.13505683e-01 -1.20593774e-01 + 3.70571571e+00 1.06220876e+00 + 9.83881041e-01 5.67713803e-01 + -2.17897705e+00 2.52925205e-01 + 1.38734039e+00 4.61287066e-01 + -1.41181602e+00 -1.67248955e-02 + -1.69974639e+00 -7.17812071e-01 + -2.01005793e-01 -7.49662056e-01 + 1.69016336e+00 3.24687979e-01 + -2.03250179e+00 -2.76108460e-01 + 3.68776848e-01 4.12536941e-01 + 7.66238259e-01 -1.84750637e-01 + -2.73989147e-01 -1.72817250e-01 + -2.18623745e+00 -2.10906798e-01 + -1.39795625e-01 3.26066094e-02 + -2.73826912e-01 -6.67586097e-02 + -1.57880654e+00 -4.99395900e-01 + 4.55950908e+00 2.29410489e+00 + -7.36479631e-01 -1.57861857e-01 + 1.92082888e+00 1.05843391e+00 + 4.29192810e+00 1.38127810e+00 + 1.61852879e+00 1.95871986e-01 + -1.95027403e+00 -5.22448168e-01 + -1.67446281e+00 -9.41497162e-01 + 6.07097859e-01 3.44178029e-01 + -3.44004683e+00 -1.49258461e+00 + 2.72114752e+00 6.00728991e-01 + 8.80685522e-01 -2.53243336e-01 + 1.39254928e+00 3.42988512e-01 + 1.14194836e-01 -8.57945694e-02 + -1.49387332e+00 -7.60860481e-01 + -1.98053285e+00 -4.86039865e-01 + 3.56008568e+00 1.08438692e+00 + 2.27833961e-01 1.09441881e+00 + -1.16716710e+00 -6.54778242e-01 + 
2.02156613e+00 5.42075758e-01 + 1.08429178e+00 -7.67420693e-01 + 6.63058455e-01 4.61680991e-01 + -1.06201537e+00 1.38862846e-01 + 3.08701875e+00 8.32580273e-01 + -4.96558108e-01 -2.47031257e-01 + 7.95109987e-01 7.59314147e-02 + -3.39903524e-01 8.71565566e-03 + 8.68351357e-01 4.78358641e-01 + 1.48750819e+00 7.63257420e-01 + -4.51224101e-01 -4.44056898e-01 + -3.02734750e-01 -2.98487961e-01 + 5.46846609e-01 7.02377629e-01 + 1.65129778e+00 3.74008231e-01 + -7.43336512e-01 3.95723531e-01 + -5.88446605e-01 -6.47520211e-01 + 3.58613167e+00 1.95024937e+00 + 3.11718883e+00 8.37984715e-01 + 1.80919244e+00 9.62644986e-01 + 5.43856371e-02 -5.86297543e-01 + -1.95186766e+00 -1.02624212e-01 + 8.95628057e-01 5.91812281e-01 + 4.97691627e-02 5.31137156e-01 + -1.07633113e+00 -2.47392788e-01 + -1.17257986e+00 -8.68528265e-01 + -8.19227665e-02 5.80579434e-03 + -2.86409787e-01 1.95812924e-01 + 1.10582671e+00 7.42853240e-01 + 4.06429774e+00 1.06557476e+00 + -3.42521792e+00 -7.74327139e-01 + 1.28468671e+00 6.20431661e-01 + 6.01201008e-01 -1.16799728e-01 + -1.85058727e-01 -3.76235293e-01 + 5.44083324e+00 2.98490868e+00 + 2.69273070e+00 7.83901153e-01 + 1.88938036e-01 -4.83222152e-01 + 1.05667256e+00 -2.57003165e-01 + 2.99711662e-01 -4.33131912e-01 + 7.73689216e-02 -1.78738364e-01 + 9.58326279e-01 6.38325706e-01 + -3.97727049e-01 2.27314759e-01 + 3.36098175e+00 1.12165237e+00 + 1.77804871e+00 6.46961933e-01 + -2.86945546e+00 -1.00395518e+00 + 3.03494815e+00 7.51814612e-01 + -1.43658194e+00 -3.55432244e-01 + -3.08455105e+00 -1.51535106e+00 + -1.55841975e+00 3.93454820e-02 + 7.96073412e-01 -3.11036969e-01 + -9.84125401e-01 -1.02064649e+00 + -7.75688143e+00 -3.65219926e+00 + 1.53816429e+00 7.65926670e-01 + -4.92712738e-01 2.32244240e-02 + -1.93166919e+00 -1.07701304e+00 + 2.03029875e-02 -7.54055699e-01 + 2.52177489e+00 1.01544979e+00 + 3.65109048e-01 -9.48328494e-01 + -1.28849143e-01 2.51947174e-01 + -1.02428075e+00 -9.37767116e-01 + -3.04179748e+00 -9.97926994e-01 + -2.51986980e+00 
-1.69117413e+00 + -1.24900838e+00 -4.16179917e-01 + 2.77943992e+00 1.22842327e+00 + -4.37434557e+00 -1.70182693e+00 + -1.60019319e+00 -4.18345639e-01 + -1.67613646e+00 -9.44087262e-01 + -9.00843245e-01 8.26378089e-02 + 3.29770621e-01 -9.07870444e-01 + -2.84650535e+00 -9.00155396e-01 + 1.57111705e+00 7.07432268e-01 + 1.24948552e+00 1.04812849e-01 + 1.81440558e+00 9.53545082e-01 + -1.74915794e+00 -1.04606288e+00 + 1.20593269e+00 -1.12607147e-02 + 1.36004919e-01 -1.09828044e+00 + 2.57480693e-01 3.34941541e-01 + 7.78775385e-01 -5.32494732e-01 + -1.79155126e+00 -6.29994129e-01 + -1.75706839e+00 -8.35100126e-01 + 4.29512012e-01 7.81426910e-02 + 3.08349370e-01 -1.27359861e-01 + 1.05560329e+00 4.55150640e-01 + 1.95662574e+00 1.17593217e+00 + 8.77376632e-01 6.57866662e-01 + 7.71311255e-01 9.15134334e-02 + -6.36978275e+00 -2.55874241e+00 + -2.98335339e+00 -1.59567024e+00 + -3.67104587e-01 1.85315291e-01 + 1.95347407e+00 -7.15503113e-02 + 8.45556363e-01 6.51256415e-02 + 9.42868521e-01 3.56647624e-01 + 2.99321875e+00 1.07505254e+00 + -2.91030538e-01 -3.77637183e-01 + 1.62870918e+00 3.37563671e-01 + 2.05773173e-01 3.43337416e-01 + -8.40879199e-01 -1.35600767e-01 + 1.38101624e+00 5.99253495e-01 + -6.93715607e+00 -2.63580662e+00 + -1.04423404e+00 -8.32865050e-01 + 1.33448476e+00 1.04863475e+00 + 6.01675207e-01 1.98585194e-01 + 2.31233993e+00 7.98628331e-01 + 1.85201313e-01 -1.76070247e+00 + 1.92006354e+00 8.45737582e-01 + 1.06320415e+00 2.93426068e-01 + -1.20360141e+00 -1.00301288e+00 + 1.95926629e+00 6.26643532e-01 + 6.04483978e-02 5.72643059e-01 + -1.04568563e+00 -5.91021496e-01 + 2.62300678e+00 9.50997831e-01 + -4.04610275e-01 3.73150879e-01 + 2.26371902e+00 8.73627529e-01 + 2.12545313e+00 7.90640352e-01 + 7.72181917e-03 1.65718952e-02 + 1.00422340e-01 -2.05562936e-01 + -1.22989802e+00 -1.01841681e-01 + 3.09064082e+00 1.04288010e+00 + 5.18274167e+00 1.34749259e+00 + -8.32075153e-01 -1.97592029e-01 + 3.84126764e-02 5.58171345e-01 + 4.99560727e-01 -4.26154438e-02 + 
4.79071151e+00 2.19728942e+00 + -2.78437968e+00 -1.17812590e+00 + -2.22804226e+00 -4.31174255e-01 + 8.50762292e-01 -1.06445261e-01 + 1.10812830e+00 -2.59118812e-01 + -2.91450155e-01 6.42802679e-01 + -1.38631532e-01 -5.88585623e-01 + -5.04120983e-01 -2.17094915e-01 + 3.41410820e+00 1.67897767e+00 + -2.23697326e+00 -6.62735244e-01 + -3.55961064e-01 -1.27647226e-01 + -3.55568274e+00 -2.49011369e+00 + -8.77586408e-01 -9.38268065e-03 + 1.52382384e-01 -5.62155760e-01 + 1.55885574e-01 1.07617069e-01 + -8.37129973e-01 -5.22259081e-01 + -2.92741750e+00 -1.35049428e+00 + -3.54670781e-01 5.69205952e-02 + 2.21030255e+00 1.34689986e+00 + 1.60787722e+00 5.75984706e-01 + 1.32294221e+00 5.31577509e-01 + 7.05672928e-01 3.34241244e-01 + 1.41406179e+00 1.15783408e+00 + -6.92172228e-01 -2.84817896e-01 + 3.28358655e-01 -2.66910083e-01 + 1.68013644e-01 -4.28016549e-02 + 2.07365974e+00 7.76496211e-01 + -3.92974907e-01 2.46796730e-01 + -5.76078636e-01 3.25676963e-01 + -1.82547204e-01 -5.06410543e-01 + 3.04754906e+00 1.16174496e+00 + -3.01090632e+00 -1.09195183e+00 + -1.44659696e+00 -6.87838682e-01 + 2.11395861e+00 9.10495785e-01 + 1.40962871e+00 1.13568678e+00 + -1.66653234e-01 -2.10012503e-01 + 3.17456029e+00 9.74502922e-01 + 2.15944820e+00 8.62807189e-01 + -3.45418719e+00 -1.33647548e+00 + -3.41357732e+00 -8.47048920e-01 + -3.06702448e-01 -6.64280634e-01 + -2.86930714e-01 -1.35268264e-01 + -3.15835557e+00 -5.43439253e-01 + 2.49541440e-01 -4.71733570e-01 + 2.71933912e+00 4.13308399e-01 + -2.43787038e+00 -1.08050547e+00 + -4.90234490e-01 -6.64069865e-01 + 8.99524451e-02 5.76180541e-01 + 5.00500404e+00 2.12125521e+00 + -1.73107940e-01 -2.28506575e-02 + 5.44938858e-01 -1.29523352e-01 + 5.13526842e+00 1.68785993e+00 + 1.70228304e+00 1.02601138e+00 + 3.58957507e+00 1.54396196e+00 + 1.85615738e+00 4.92916197e-01 + 2.55772147e+00 7.88438908e-01 + -1.57008279e+00 -4.17377300e-01 + -1.42548604e+00 -3.63684860e-01 + -8.52026118e-01 2.72052686e-01 + -5.10563077e+00 -2.35665994e+00 + -2.95517031e+00 
-1.84945297e+00 + -2.91947959e+00 -1.66016784e+00 + -4.21462387e+00 -1.41131535e+00 + 6.59901121e-01 4.87156314e-01 + -9.75352532e-01 -4.50231285e-01 + -5.94084444e-01 -1.16922670e+00 + 7.50554615e-01 -9.83692552e-01 + 1.07054926e+00 2.77143030e-01 + -3.88079578e-01 -4.17737309e-02 + -9.59373733e-01 -8.85454886e-01 + -7.53560665e-02 -5.16223870e-02 + 9.84108158e-01 -5.89290700e-02 + 1.87272961e-01 -4.34238391e-01 + 6.86509981e-01 -3.15116460e-01 + -1.07762538e+00 6.58984161e-02 + 6.09266592e-01 6.91808473e-02 + -8.30529954e-01 -7.00454791e-01 + -9.13179464e-01 -6.31712891e-01 + 7.68744851e-01 1.09840676e+00 + -1.07606690e+00 -8.78390282e-01 + -1.71038184e+00 -5.73606033e-01 + 8.75982765e-01 3.66343143e-01 + -7.04919009e-01 -8.49182590e-01 + -1.00274668e+00 -7.99573611e-01 + -1.05562848e+00 -5.84060076e-01 + 4.03490015e+00 1.28679206e+00 + -3.53484804e+00 -1.71381255e+00 + 2.31527363e-01 1.04179397e-01 + -3.58592392e-02 3.74895739e-01 + 3.92253428e+00 1.81852726e+00 + -7.27384249e-01 -6.45605128e-01 + 4.65678097e+00 2.41379899e+00 + 1.16750534e+00 7.60718205e-01 + 1.15677059e+00 7.96225550e-01 + -1.42920261e+00 -4.66946295e-01 + 3.71148192e+00 1.88060191e+00 + 2.44052407e+00 3.84472199e-01 + -1.64535035e+00 -8.94530036e-01 + -3.69608753e+00 -1.36402754e+00 + 2.24419208e+00 9.69744889e-01 + 2.54822427e+00 1.22613039e+00 + 3.77484909e-01 -5.98521878e-01 + -3.61521175e+00 -1.11123912e+00 + 3.28113127e+00 1.52551775e+00 + -3.51030902e+00 -1.53913980e+00 + -2.44874505e+00 -6.30246005e-01 + -3.42516153e-01 -5.07352665e-01 + 1.09110502e+00 6.36821628e-01 + -2.49434967e+00 -8.02827146e-01 + 1.41763139e+00 -3.46591820e-01 + 1.61108619e+00 5.93871102e-01 + 3.97371717e+00 1.35552499e+00 + -1.33437177e+00 -2.83908670e-01 + -1.41606483e+00 -1.76402601e-01 + 2.23945322e-01 -1.77157065e-01 + 2.60271569e+00 2.40778251e-01 + -2.82213895e-02 1.98255474e-01 + 4.20727940e+00 1.31490863e+00 + 3.36944889e+00 1.57566635e+00 + 3.53049396e+00 1.73579350e+00 + -1.29170202e+00 -1.64196290e+00 
+ 9.27295604e-01 9.98808036e-01 + 1.75321843e-01 -2.83267817e-01 + -2.19069578e+00 -1.12814358e+00 + 1.66606031e+00 7.68006933e-01 + -7.13826035e-01 5.20881684e-02 + -3.43821888e+00 -2.36137021e+00 + -5.93210310e-01 1.21843813e-01 + -4.09800822e+00 -1.39893953e+00 + 2.74110954e+00 1.52728606e+00 + 1.72652512e+00 -1.25435113e-01 + 1.97722357e+00 6.40667481e-01 + 4.18635780e-01 3.57018509e-01 + -1.78303569e+00 -2.11864764e-01 + -3.52809366e+00 -2.58794450e-01 + -4.72407090e+00 -1.63870734e+00 + 1.73917807e+00 8.73251829e-01 + 4.37979356e-01 8.49210569e-01 + 3.93791881e+00 1.76269490e+00 + 2.79065411e+00 1.04019042e+00 + -8.47426142e-01 -3.40136892e-01 + -4.24389181e+00 -1.80253120e+00 + -1.86675870e+00 -7.64558265e-01 + 9.46212675e-01 -7.77681445e-02 + -2.82448462e+00 -1.33592449e+00 + -2.57938567e+00 -1.56554690e+00 + -2.71615767e+00 -6.27667233e-01 + -1.55999166e+00 -5.81013466e-01 + -4.24696864e-01 -7.44673250e-01 + 1.67592970e+00 7.68164292e-01 + 8.48455216e-01 -6.05681126e-01 + 6.12575454e+00 1.65607584e+00 + 1.38207327e+00 2.39261863e-01 + 3.13364450e+00 1.17154698e+00 + 1.71694858e+00 1.26744905e+00 + -1.61746367e+00 -8.80098073e-01 + -8.52196756e-01 -9.27299728e-01 + -1.51562462e-01 -8.36552490e-02 + -7.04792753e-01 -1.24726713e-02 + -3.35265757e+00 -1.82176312e+00 + 3.32173170e-01 -1.33405580e-01 + 4.95841013e-01 4.58292712e-01 + 1.57713955e+00 7.79272991e-01 + 2.09743109e+00 9.23542557e-01 + 3.90450311e-03 -8.42873164e-01 + 2.59519038e+00 7.56479591e-01 + -5.77643976e-01 -2.36401904e-01 + -5.22310654e-01 1.34187830e-01 + -2.22096086e+00 -7.75507719e-01 + 1.35907831e+00 7.80197510e-01 + 3.80355868e+00 1.16983476e+00 + 3.82746596e+00 1.31417718e+00 + 3.30451183e+00 1.55398159e+00 + -3.42917814e-01 -8.62281222e-02 + -2.59093020e+00 -9.29883526e-01 + 1.40928562e+00 1.08398346e+00 + 1.54400137e-01 3.35881092e-01 + 1.59171586e+00 1.18855802e+00 + -5.25164002e-01 -1.03104220e-01 + 2.20067959e+00 1.37074713e+00 + 6.97860830e-01 6.27718548e-01 + -4.59743507e-01 
1.36061163e-01 + -1.04691963e-01 -2.16271727e-01 + -1.08905573e+00 -5.95510769e-01 + -1.00826983e+00 -5.38509162e-02 + -3.16402719e+00 -1.33414216e+00 + 1.47870874e-01 1.75234619e-01 + -2.57078234e-01 7.03316889e-02 + 1.81073945e+00 4.26901462e-01 + 2.65476530e+00 6.74217273e-01 + 1.27539811e+00 6.22914081e-01 + -3.76750499e-01 -1.20629449e+00 + 1.00177595e+00 -1.40660091e-01 + -2.98919265e+00 -1.65145013e+00 + -2.21557682e+00 -8.11123452e-01 + -3.22635378e+00 -1.65639056e+00 + -2.72868553e+00 -1.02812087e+00 + 1.26042797e+00 8.49005248e-01 + -9.38318534e-01 -9.87588651e-01 + 3.38013194e-01 -1.00237461e-01 + 1.91175691e+00 8.48716369e-01 + 4.30244344e-01 6.05539915e-02 + 2.21783435e+00 3.03268204e-01 + 1.78019576e+00 1.27377108e+00 + 1.59733274e+00 4.40674687e-02 + 3.97428484e+00 2.20881566e+00 + -2.41108677e+00 -6.01410418e-01 + -2.50796499e+00 -5.71169866e-01 + -3.71957427e+00 -1.38195726e+00 + -1.57992670e+00 1.32068593e-01 + -1.35278851e+00 -6.39349270e-01 + 1.23075932e+00 2.40445409e-01 + 1.35606530e+00 4.33180078e-01 + 9.60968518e-02 2.26734255e-01 + 6.22975063e-01 5.03431915e-02 + -1.47624851e+00 -3.60568238e-01 + -2.49337808e+00 -1.15083052e+00 + 2.15717792e+00 1.03071559e+00 + -3.07814376e-02 1.38700314e-02 + 4.52049499e-02 -4.86409775e-01 + 2.58231061e+00 1.14327809e-01 + 1.10999138e+00 -5.18568405e-01 + -2.19426443e-01 -5.37505538e-01 + -4.44740298e-01 6.78099955e-01 + 4.03379080e+00 1.49825720e+00 + -5.13182408e-01 -4.90201950e-01 + -6.90139716e-01 1.63875126e-01 + -8.17281461e-01 2.32155064e-01 + -2.92357619e-01 -8.02573544e-01 + -1.80769841e+00 -7.58907326e-01 + 2.16981590e+00 1.06728873e+00 + 1.98995203e-01 -6.84176682e-02 + -2.39546753e+00 -2.92873789e-01 + -4.24251021e+00 -1.46255564e+00 + -5.01411291e-01 -5.95712813e-03 + 2.68085809e+00 1.42883780e+00 + -4.13289873e+00 -1.62729388e+00 + 1.87957843e+00 3.63341638e-01 + -1.15270744e+00 -3.03563774e-01 + -4.43994248e+00 -2.97323905e+00 + -7.17067733e-01 -7.08349542e-01 + -3.28870393e+00 
-1.19263863e+00 + -7.55325944e-01 -5.12703329e-01 + -2.07291938e+00 -2.65025085e-01 + -7.50073814e-01 -1.70771041e-01 + -8.77381404e-01 -5.47417325e-01 + -5.33725862e-01 5.15837119e-01 + 8.45056431e-01 2.82125560e-01 + -1.59598637e+00 -1.38743235e+00 + 1.41362902e+00 1.06407789e+00 + 1.02584504e+00 -3.68219466e-01 + -1.04644488e+00 -1.48769392e-01 + 2.66990191e+00 8.57633492e-01 + -1.84251857e+00 -9.82430175e-01 + 9.71404204e-01 -2.81934209e-01 + -2.50177989e+00 -9.21260335e-01 + -1.31060074e+00 -5.84488113e-01 + -2.12129400e-01 -3.06244708e-02 + -5.28933882e+00 -2.50663129e+00 + 1.90220541e+00 1.08662918e+00 + -3.99366086e-02 -6.87178973e-01 + -4.93417342e-01 4.37354182e-01 + 2.13494486e+00 1.37679569e+00 + 2.18396765e+00 5.81023868e-01 + -3.07866587e+00 -1.45384974e+00 + 6.10894119e-01 -4.17050124e-01 + -1.88766952e+00 -8.86160058e-01 + 3.34527253e+00 1.78571260e+00 + 6.87769059e-01 -5.01157336e-01 + 2.60470837e+00 1.45853560e+00 + -6.49315691e-01 -9.16112805e-01 + -1.29817687e+00 -2.15924339e-01 + -1.20100409e-03 -4.03137422e-01 + -1.36471594e+00 -6.93266356e-01 + 1.38682062e+00 7.15131598e-01 + 2.47830103e+00 1.24862305e+00 + -2.78288147e+00 -1.03329235e+00 + -7.33443403e-01 -6.11041652e-01 + -4.12745671e-01 -5.96133390e-02 + -2.58632336e+00 -4.51557058e-01 + -1.16570367e+00 -1.27065510e+00 + 2.76187104e+00 2.21895451e-01 + -3.80443767e+00 -1.66319902e+00 + 9.84658633e-01 6.81475569e-01 + 9.33814584e-01 -4.89335563e-02 + -4.63427997e-01 1.72989539e-01 + 1.82401546e+00 3.60164021e-01 + -5.36521077e-01 -8.08691351e-01 + -1.37367030e+00 -1.02126160e+00 + -3.70310682e+00 -1.19840844e+00 + -1.51894242e+00 -3.89510223e-01 + -3.67347940e-01 -3.25540516e-02 + -1.00988595e+00 1.82802194e-01 + 2.01622795e+00 7.86367901e-01 + 1.02440231e+00 8.79780360e-01 + -3.05971480e+00 -8.40901527e-01 + 2.73909457e+00 1.20558628e+00 + 2.39559056e+00 1.10786694e+00 + 1.65471544e+00 7.33824651e-01 + 2.18546787e+00 6.41168955e-01 + 1.47152266e+00 3.91839132e-01 + 1.45811155e+00 
5.21820495e-01 + -4.27531469e-02 -3.52343068e-03 + -9.54948010e-01 -1.52313876e-01 + 7.57151215e-01 -5.68728854e-03 + -8.46205751e-01 -7.54580229e-01 + 4.14493548e+00 1.45532780e+00 + 4.58688968e-01 -4.54012803e-02 + -1.49295381e+00 -4.57471758e-01 + 1.80020351e+00 8.13724973e-01 + -5.82727738e+00 -2.18269581e+00 + -2.09017809e+00 -1.18305177e+00 + -2.31628303e+00 -7.21600235e-01 + -8.09679091e-01 -1.49101752e-01 + 8.88005605e-01 8.57940857e-01 + -1.44148219e+00 -3.10926299e-01 + 3.68828186e-01 -3.08848059e-01 + -6.63267389e-01 -8.58950139e-02 + -1.14702569e+00 -6.32147854e-01 + -1.51741715e+00 -8.53330564e-01 + -1.33903718e+00 -1.45875547e-01 + 4.12485387e+00 1.85620435e+00 + -2.42353639e+00 -2.92669850e-01 + 1.88708583e+00 9.35984730e-01 + 2.15585179e+00 6.30469051e-01 + -1.13627973e-01 -1.62554045e-01 + 2.04540494e+00 1.36599834e+00 + 2.81591381e+00 1.60897941e+00 + 3.02736260e-02 3.83255815e-03 + 7.97634013e-02 -2.82035099e-01 + -3.24607473e-01 -5.30065956e-01 + -3.91862894e+00 -1.94083334e+00 + 1.56360901e+00 7.93882743e-01 + -1.03905772e+00 6.25590229e-01 + 2.54746492e+00 1.64233560e+00 + -4.80774423e-01 -8.92298032e-02 + 9.06979990e-02 1.05020427e+00 + -2.47521290e+00 -1.78275982e-01 + -3.91871729e-01 3.80285423e-01 + 1.00658382e+00 4.58947483e-01 + 4.68102941e-01 1.02992741e+00 + 4.44242568e-01 2.89870239e-01 + 3.29684452e+00 1.44677474e+00 + -2.24983007e+00 -9.65574499e-01 + -3.54453926e-01 -3.99020325e-01 + -3.87429665e+00 -1.90079739e+00 + 2.02656674e+00 1.12444894e+00 + 3.77011621e+00 1.43200852e+00 + 1.61259275e+00 4.65417399e-01 + 2.28725434e+00 6.79181395e-01 + 2.75421009e+00 2.27327345e+00 + -2.40894409e+00 -1.03926359e+00 + 1.52996651e-01 -2.73373046e-02 + -2.63218977e+00 -7.22802821e-01 + 2.77688169e+00 1.15310186e+00 + 1.18832341e+00 4.73457165e-01 + -2.35536326e+00 -1.08034554e+00 + -5.84221627e-01 1.03505984e-02 + 2.96730300e+00 1.33478306e+00 + -8.61947692e-01 6.09137051e-02 + 8.22343921e-01 -8.14155286e-02 + 1.75809015e+00 1.07921470e+00 + 
1.19501279e+00 1.05309972e+00 + -1.75901792e+00 9.75320161e-02 + 1.64398635e+00 9.54384323e-01 + -2.21878052e-01 -3.64847144e-01 + -2.03128968e+00 -8.57866419e-01 + 1.86750633e+00 7.08524487e-01 + 8.03972976e-01 3.47404314e-01 + 3.41203749e+00 1.39810900e+00 + 4.22397681e-01 -6.41440488e-01 + -4.88493360e+00 -1.58967816e+00 + -1.67649284e-01 -1.08485915e-01 + 2.11489023e+00 1.50506158e+00 + -1.81639929e+00 -3.85542192e-01 + 2.24044819e-01 -1.45100577e-01 + -3.39262411e+00 -1.44394324e+00 + 1.68706599e+00 2.29199618e-01 + -1.94093257e+00 -1.65975814e-01 + 8.28143367e-01 5.92109281e-01 + -8.29587998e-01 -9.57130831e-01 + -1.50011401e+00 -8.36802092e-01 + 2.40770449e+00 9.32820177e-01 + 7.41391309e-02 3.12878473e-01 + 1.87745264e-01 6.19231425e-01 + 9.57622692e-01 -2.20640033e-01 + 3.18479243e+00 1.02986233e+00 + 2.43133846e+00 8.41302677e-01 + -7.09963834e-01 1.99718943e-01 + -2.88253498e-01 -3.62772094e-01 + 5.14052574e+00 1.79304595e+00 + -3.27930993e+00 -1.29177973e+00 + -1.16723536e+00 1.29519656e-01 + 1.04801056e+00 3.41508300e-01 + -3.99256195e+00 -2.51176471e+00 + -7.62824318e-01 -6.84242153e-01 + 2.71524986e-02 5.35157164e-02 + 3.26430102e+00 1.34887262e+00 + -1.72357766e+00 -4.94524388e-01 + -3.81149536e+00 -1.28121944e+00 + 3.36919354e+00 1.10672075e+00 + -3.14841757e+00 -7.10713767e-01 + -3.16463676e+00 -7.58558435e-01 + -2.44745969e+00 -1.08816514e+00 + 2.79173264e-01 -2.19652051e-02 + 4.15309883e-01 6.07502790e-01 + -9.51007417e-01 -5.83976336e-01 + -1.47929839e+00 -8.39850409e-01 + 2.38335703e+00 6.16055149e-01 + -7.47749031e-01 -5.56164928e-01 + -3.65643622e-01 -5.06684411e-01 + -1.76634163e+00 -7.86382097e-01 + 6.76372222e-01 -3.06592181e-01 + -1.33505058e+00 -1.18301441e-01 + 3.59660179e+00 2.00424178e+00 + -7.88912762e-02 8.71956146e-02 + 1.22656397e+00 1.18149583e+00 + 4.24919729e+00 1.20082355e+00 + 2.94607456e+00 1.00676505e+00 + 7.46061275e-02 4.41761753e-02 + -2.47738025e-02 1.92737701e-01 + -2.20509316e-01 -3.79163193e-01 + -3.50222190e-01 
3.58727299e-01 + -3.64788014e+00 -1.36107312e+00 + 3.56062799e+00 9.27032742e-01 + 1.04317289e+00 6.08035970e-01 + 4.06718718e-01 3.00628051e-01 + 4.33158086e+00 2.25860714e+00 + 2.13917145e-01 -1.72757967e-01 + -1.40637998e+00 -1.14119465e+00 + 3.61554872e+00 1.87797348e+00 + 1.01726871e+00 5.70255097e-01 + -7.04902551e-01 2.16444147e-01 + -2.51492186e+00 -8.52997369e-01 + 1.85097530e+00 1.15124496e+00 + -8.67569714e-01 -3.05682432e-01 + 8.07550858e-01 5.88901608e-01 + 1.85186755e-01 -1.94589367e-01 + -1.23378238e+00 -7.84128347e-01 + -1.22713161e+00 -4.21218235e-01 + 2.97751165e-01 2.81055275e-01 + 4.77703554e+00 1.66265524e+00 + 2.51549669e+00 7.49980674e-01 + 2.76510822e-01 1.40456909e-01 + 1.98740905e+00 -1.79608212e-01 + 9.35429145e-01 8.44344180e-01 + -1.20854492e+00 -5.00598453e-01 + 2.29936219e+00 8.10236668e-01 + 6.92555544e-01 -2.65891331e-01 + -1.58050994e+00 2.31237821e-01 + -1.50864880e+00 -9.49661690e-01 + -1.27689206e+00 -7.18260016e-01 + -3.12517127e+00 -1.75587113e+00 + 8.16062912e-02 -6.56551804e-01 + -5.02479939e-01 -4.67162543e-01 + -5.47435788e+00 -2.47799576e+00 + 1.95872901e-02 5.80874076e-01 + -1.59064958e+00 -6.34554756e-01 + -3.77521478e+00 -1.74301790e+00 + 5.89628224e-01 8.55736553e-01 + -1.81903543e+00 -7.50011008e-01 + 1.38557775e+00 3.71490991e-01 + 9.70032652e-01 -7.11356016e-01 + 2.63539625e-01 -4.20994771e-01 + 2.12154222e+00 8.19081400e-01 + -6.56977937e-01 -1.37810098e-01 + 8.91309581e-01 2.77864361e-01 + -7.43693195e-01 -1.46293770e-01 + 2.24447769e+00 4.00911438e-01 + -2.25169262e-01 2.04148801e-02 + 1.68744684e+00 9.47573007e-01 + 2.73086373e-01 3.30877195e-01 + 5.54294414e+00 2.14198009e+00 + -8.49238733e-01 3.65603298e-02 + 2.39685712e+00 1.17951039e+00 + -2.58230528e+00 -5.52116673e-01 + 2.79785277e+00 2.88833717e-01 + -1.96576188e-01 1.11652123e+00 + -4.69383301e-01 1.96496282e-01 + -1.95011845e+00 -6.15235169e-01 + 1.03379890e-02 2.33701239e-01 + 4.18933607e-01 2.77939814e-01 + -1.18473337e+00 -4.10051126e-01 + 
-7.61499744e-01 -1.43658094e+00 + -1.65586092e+00 -3.41615303e-01 + -5.58523700e-02 -5.21837080e-01 + -2.40331088e+00 -2.64521583e-01 + 2.24925206e+00 6.79843335e-02 + 1.46360479e+00 1.04271443e+00 + -3.09255443e+00 -1.82548953e+00 + 2.11325841e+00 1.14996627e+00 + -8.70657797e-01 1.02461839e-01 + -5.71056521e-01 9.71232588e-02 + -3.37870752e+00 -1.54091877e+00 + 1.03907189e+00 -1.35661392e-01 + 8.40057486e-01 6.12172413e-02 + -1.30998234e+00 -1.34077226e+00 + 7.53744974e-01 1.49447350e-01 + 9.13995056e-01 -1.81227962e-01 + 2.28386229e-01 3.74498520e-01 + 2.54829151e-01 -2.88802704e-01 + 1.61709009e+00 2.09319193e-01 + -1.12579380e+00 -5.95955338e-01 + -2.69610726e+00 -2.76222736e-01 + -2.63773329e+00 -7.84491970e-01 + -2.62167427e+00 -1.54792874e+00 + -4.80639856e-01 -1.30582102e-01 + -1.26130891e+00 -8.86841840e-01 + -1.24951950e+00 -1.18182622e+00 + -1.40107574e+00 -9.13695575e-01 + 4.99872179e-01 4.69014702e-01 + -2.03550193e-02 -1.48859738e-01 + -1.50189069e+00 -2.97714278e-02 + -2.07846113e+00 -7.29937809e-01 + -5.50576792e-01 -7.03151525e-01 + -3.88069238e+00 -1.63215295e+00 + 2.97032988e+00 6.43571144e-01 + -1.85999273e-01 1.18107620e+00 + 1.79249709e+00 6.65356160e-01 + 2.68842472e+00 1.35703255e+00 + 1.07675417e+00 1.39845588e-01 + 8.01226349e-01 2.11392275e-01 + 9.64329379e-01 3.96146195e-01 + -8.22529511e-01 1.96080831e-01 + 1.92481841e+00 4.62985744e-01 + 3.69756927e-01 3.77135799e-01 + 1.19807835e+00 8.87715050e-01 + -1.01363587e+00 -2.48151636e-01 + 8.53071010e-01 4.96887868e-01 + -3.41120553e+00 -1.35401843e+00 + -2.64787381e+00 -1.08690563e+00 + -1.11416759e+00 -4.43848915e-01 + 1.46242648e+00 6.17106076e-02 + -7.52968881e-01 -9.20972209e-01 + -1.22492228e+00 -5.40327617e-01 + 1.08001827e+00 5.29593785e-01 + -2.58706464e-01 1.13022085e-01 + -4.27394011e-01 1.17864354e-02 + -3.20728413e+00 -1.71224737e-01 + 1.71398530e+00 8.68885893e-01 + 2.12067866e+00 1.45092772e+00 + 4.32782616e-01 -3.34117769e-01 + 7.80084374e-01 -1.35100217e-01 + 
-2.05547729e+00 -4.70217750e-01 + 2.38379736e+00 1.09186058e+00 + -2.80825477e+00 -1.03320187e+00 + 2.63434576e+00 1.15671733e+00 + -1.60936214e+00 1.91843035e-01 + -5.02298769e+00 -2.32820708e+00 + 1.90349195e+00 1.45215416e+00 + 3.00232888e-01 3.24412586e-01 + -2.46503943e+00 -1.19550010e+00 + 1.06304233e+00 2.20136246e-01 + -2.99101388e+00 -1.58299318e+00 + 2.30071719e+00 1.12881362e+00 + -2.37587247e+00 -8.08298336e-01 + 7.27006308e-01 3.80828984e-01 + 2.61199061e+00 1.56473491e+00 + 8.33936357e-01 -1.42189425e-01 + 3.13291605e+00 1.77771210e+00 + 2.21917371e+00 5.68427075e-01 + 2.38867649e+00 9.06637262e-01 + -6.92959466e+00 -3.57682881e+00 + 2.57904824e+00 5.93959108e-01 + 2.71452670e+00 1.34436199e+00 + 4.39988761e+00 2.13124672e+00 + 5.71783077e-01 5.08346173e-01 + -3.65399429e+00 -1.18192861e+00 + 4.46176453e-01 3.75685594e-02 + -2.97501495e+00 -1.69459236e+00 + 1.60855728e+00 9.20930014e-01 + -1.44270290e+00 -1.93922306e-01 + 1.67624229e+00 1.66233866e+00 + -1.42579598e+00 -1.44990145e-01 + 1.19923176e+00 4.58490278e-01 + -9.00068460e-01 5.09701825e-02 + -1.69391694e+00 -7.60070300e-01 + -1.36576440e+00 -5.24244256e-01 + -1.03016748e+00 -3.44625878e-01 + 2.40519313e+00 1.09947587e+00 + 1.50365433e+00 1.06464802e+00 + -1.07609727e+00 -3.68897187e-01 + 2.44969069e+00 1.28486192e+00 + -1.25610307e+00 -1.14644789e+00 + 2.05962899e+00 4.31162369e-01 + -7.15886908e-01 -6.11587804e-02 + -6.92354119e-01 -7.85019920e-01 + -1.63016508e+00 -5.96944975e-01 + 1.90352536e+00 1.28197457e+00 + -4.01535243e+00 -1.81934488e+00 + -1.07534435e+00 -2.10544784e-01 + 3.25500866e-01 7.69603661e-01 + 2.18443365e+00 6.59773335e-01 + 8.80856790e-01 6.39505913e-01 + -2.23956372e-01 -4.65940132e-01 + -1.06766519e+00 -5.38388505e-03 + 7.25556863e-01 -2.91123488e-01 + -4.69451411e-01 7.89182650e-02 + 2.58146587e+00 1.29653243e+00 + 1.53747468e-01 7.69239075e-01 + -4.61152262e-01 -4.04151413e-01 + 1.48183517e+00 8.10079506e-01 + -1.83402614e+00 -1.36939322e+00 + 1.49315501e+00 
7.95225425e-01 + 1.41922346e+00 1.05582774e-01 + 1.57473493e-01 9.70795657e-01 + -2.67603254e+00 -7.48562280e-01 + -8.49156216e-01 -6.05762529e-03 + 1.12944274e+00 3.67741591e-01 + 1.94228071e-01 5.28188141e-01 + -3.65610158e-01 4.05851838e-01 + -1.98839111e+00 -1.38452764e+00 + 2.73765752e+00 8.24150530e-01 + 7.63728641e-01 3.51617707e-01 + 5.78307267e+00 1.68103612e+00 + 2.27547227e+00 3.60876164e-01 + -3.50681697e+00 -1.74429984e+00 + 4.01241184e+00 1.26227829e+00 + 2.44946343e+00 9.06119057e-01 + -2.96638941e+00 -9.01532322e-01 + 1.11267643e+00 -3.43333381e-01 + -6.61868994e-01 -3.44666391e-01 + -8.34917179e-01 5.69478372e-01 + -1.91888454e+00 -3.03791075e-01 + 1.50397636e+00 8.31961240e-01 + 6.12260198e+00 2.16851807e+00 + 1.34093127e+00 8.86649385e-01 + 1.48748519e+00 8.26273697e-01 + 7.62243068e-01 2.64841396e-01 + -2.17604986e+00 -3.54219958e-01 + 2.64708640e-01 -4.38136718e-02 + 1.44725372e+00 1.18499914e-01 + -6.71259446e-01 -1.19526851e-01 + 2.40134595e-01 -8.90042323e-02 + -3.57238199e+00 -1.23166201e+00 + -3.77626645e+00 -1.19533443e+00 + -3.81101035e-01 -4.94160532e-01 + -3.02758757e+00 -1.18436066e+00 + 2.59116298e-01 1.38023047e+00 + 4.17900116e+00 1.12065959e+00 + 1.54598848e+00 2.89806755e-01 + 1.00656475e+00 1.76974511e-01 + -4.15730234e-01 -6.22681694e-01 + -6.00903565e-01 -1.43256959e-01 + -6.03652508e-01 -5.09936379e-01 + -1.94096658e+00 -9.48789544e-01 + -1.74464105e+00 -8.50491590e-01 + 1.17652544e+00 1.88118317e+00 + 2.35507776e+00 1.44000205e+00 + 2.63067924e+00 1.06692988e+00 + 2.88805386e+00 1.23924715e+00 + 8.27595008e-01 5.75364692e-01 + 3.91384216e-01 9.72781920e-02 + -1.03866816e+00 -1.37567768e+00 + -1.34777969e+00 -8.40266025e-02 + -4.12904508e+00 -1.67618340e+00 + 1.27918111e+00 3.52085961e-01 + 4.15361174e-01 6.28896189e-01 + -7.00539496e-01 4.80447955e-02 + -1.62332639e+00 -5.98236485e-01 + 1.45957300e+00 1.00305154e+00 + -3.06875603e+00 -1.25897545e+00 + -1.94708176e+00 4.85143006e-01 + 3.55744156e+00 -1.07468822e+00 + 
1.21602223e+00 1.28768827e-01 + 1.89093098e+00 -4.70835659e-01 + -6.55759125e+00 2.70114082e+00 + 8.96843535e-01 -3.98115252e-01 + 4.13450429e+00 -2.32069236e+00 + 2.37764218e+00 -1.09098890e+00 + -1.11388901e+00 6.27083097e-01 + -6.34116929e-01 4.62816387e-01 + 2.90203079e+00 -1.33589143e+00 + 3.17457598e+00 -5.13575945e-01 + -1.76362299e+00 5.71820693e-01 + 1.66103362e+00 -8.99466249e-01 + -2.53947433e+00 8.40084780e-01 + 4.36631397e-01 7.24234261e-02 + -1.87589394e+00 5.08529113e-01 + 4.49563965e+00 -9.43365992e-01 + 1.78876299e+00 -1.27076149e+00 + -1.16269107e-01 -4.55078316e-01 + 1.92966079e+00 -8.05371385e-01 + 2.20632583e+00 -9.00919345e-01 + 1.52387824e+00 -4.82391996e-01 + 8.04004564e-01 -2.73650595e-01 + -7.75326067e-01 1.07469566e+00 + 1.83226282e+00 -4.52173344e-01 + 1.25079758e-01 -3.52895417e-02 + -9.90957437e-01 8.55993130e-01 + 1.71623322e+00 -7.08691667e-01 + -2.86175924e+00 6.75160955e-01 + -8.40817853e-01 -1.00361809e-01 + 1.33393000e+00 -4.65788123e-01 + 5.29394114e-01 -5.44881619e-02 + -8.07435599e-01 8.27353370e-01 + -4.33165824e+00 1.97299638e+00 + 1.26452422e+00 -8.34070486e-01 + 1.45996394e-02 2.97736043e-01 + -1.64489287e+00 6.72839598e-01 + -5.74234578e+00 3.20975117e+00 + 2.13841341e-02 3.64514015e-01 + 6.68084924e+00 -2.27464254e+00 + -3.22881590e+00 8.01879324e-01 + 3.02534313e-01 -4.56222796e-01 + -5.84520734e+00 1.95678162e+00 + 2.81515232e+00 -1.72101318e+00 + -2.39620908e-01 2.69145522e-01 + -7.41669691e-01 -2.30283281e-01 + -2.15682714e+00 3.45313021e-01 + 1.23475788e+00 -7.32276553e-01 + -1.71816113e-01 1.20419560e-02 + 1.89174235e+00 2.27435901e-01 + -3.64511114e-01 1.72260361e-02 + -3.24143860e+00 6.50125817e-01 + -2.25707409e+00 5.66970751e-01 + 1.03901456e+00 -1.00588433e+00 + -5.09159710e+00 1.58736109e+00 + 1.45534075e+00 -5.83787452e-01 + 4.28879587e+00 -1.58006866e+00 + 8.52384427e-01 -1.11042299e+00 + 4.51431615e+00 -2.63844265e+00 + -4.33042648e+00 1.86497078e+00 + -2.13568046e+00 5.82559743e-01 + -4.42568887e+00 
1.26131214e+00 + 3.15821315e+00 -1.61515905e+00 + -3.14125204e+00 8.49604386e-01 + 6.54152300e-01 -2.04624711e-01 + -3.73374317e-01 9.94187820e-02 + -3.96177282e+00 1.27245623e+00 + 9.59825199e-01 -1.15547861e+00 + 3.56902055e+00 -1.46591091e+00 + 1.55433633e-02 6.93544345e-01 + 1.15684646e+00 -4.99836352e-01 + 3.11824573e+00 -4.75900506e-01 + -8.61706369e-01 -3.50774059e-01 + 9.89057391e-01 -7.16878802e-01 + -4.94787870e+00 2.09137481e+00 + 1.37777347e+00 -1.34946349e+00 + -1.13161577e+00 8.05114754e-01 + 8.12020675e-01 -1.04849421e+00 + 4.73783881e+00 -2.26718812e+00 + 8.99579366e-01 -8.89764451e-02 + 4.78524868e+00 -2.25795843e+00 + 1.75164590e+00 -1.73822209e-01 + 1.30204590e+00 -7.26724717e-01 + -7.26526403e-01 -5.23925361e-02 + 2.01255351e+00 -1.69965366e+00 + 9.87852740e-01 -4.63577220e-01 + 2.45957762e+00 -1.29278962e+00 + -3.13817948e+00 1.64433038e+00 + -1.76302159e+00 9.62784302e-01 + -1.91106331e+00 5.81460008e-01 + -3.30883001e+00 1.30378978e+00 + 5.54376450e-01 3.78814272e-01 + 1.09982111e+00 -1.47969612e+00 + -2.61300705e-02 -1.42573464e-01 + -2.22096157e+00 7.75684440e-01 + 1.70319323e+00 -2.89738444e-01 + -1.43223842e+00 6.39284281e-01 + 2.34360959e-01 -1.64379268e-01 + -2.67147991e+00 9.46548086e-01 + 1.51131425e+00 -4.91594395e-01 + -2.48446856e+00 1.01286123e+00 + 1.50534658e-01 -2.94620246e-01 + -1.66966792e+00 1.67755508e+00 + -1.50094241e+00 3.30163095e-01 + 2.27681194e+00 -1.08064317e+00 + 2.05122965e+00 -1.15165939e+00 + -4.23509309e-01 -6.56906167e-02 + 1.80084023e+00 -1.07228556e+00 + -2.65769521e+00 1.18023206e+00 + 2.02852676e+00 -8.06793574e-02 + -4.49544185e+00 2.68200163e+00 + -7.50043216e-01 1.17079331e+00 + 6.80060893e-02 3.99055351e-01 + -3.83634635e+00 1.38406887e+00 + 3.24858545e-01 -9.25273218e-02 + -2.19895100e+00 1.47819500e+00 + -3.61569522e-01 -1.03188739e-01 + 1.12180375e-01 -9.52696354e-02 + -1.31477803e+00 1.79900570e-01 + 2.39573628e+00 -6.09739269e-01 + -1.00135700e+00 6.02837296e-01 + -4.11994589e+00 2.49599192e+00 + 
-1.54196236e-01 -4.84921951e-01 + 5.92569908e-01 -1.87310359e-01 + 3.85407741e+00 -1.50979925e+00 + 5.17802528e+00 -2.26032607e+00 + -1.37018916e+00 1.87111822e-01 + 8.46682996e-01 -3.56676331e-01 + -1.17559949e+00 5.29057734e-02 + -5.56475671e-02 6.79049243e-02 + 1.07851745e+00 -5.14535101e-01 + -2.71622446e+00 1.00151846e+00 + -1.08477208e+00 8.81391054e-01 + 5.50755824e-01 -5.20577727e-02 + 4.70885495e+00 -2.04220397e+00 + -1.87375336e-01 -6.16962830e-02 + 3.52097100e-01 2.21163550e-01 + 7.07929984e-01 -1.75827590e-01 + -1.22149219e+00 1.83084346e-01 + 2.58247412e+00 -6.15914898e-01 + -6.01206182e-01 -2.29832987e-01 + 9.83360449e-01 -3.75870060e-01 + -3.20027685e+00 1.35467480e+00 + 1.79178978e+00 -1.38531981e+00 + -3.30376867e-01 -1.16250192e-01 + -1.89053055e+00 5.68463567e-01 + -4.20604849e+00 1.65429681e+00 + -1.01185529e+00 1.92801240e-01 + -6.18819882e-01 5.42206996e-01 + -5.08091672e+00 2.61598591e+00 + -2.62570344e+00 2.51590658e+00 + 3.05577906e+00 -1.49090609e+00 + 2.77609677e+00 -1.37681378e+00 + -7.93515301e-02 4.28072744e-01 + -2.08359471e+00 8.94334295e-01 + 2.20163801e+00 4.01127167e-02 + -1.18145785e-01 -2.06822464e-01 + -2.74788298e-01 2.96250607e-01 + 1.59613555e+00 -3.87246203e-01 + -3.82971472e-01 -3.39716093e-02 + -4.20311307e-02 3.88529510e-01 + 1.52128574e+00 -9.33138876e-01 + -9.06584458e-01 -2.75016094e-02 + 3.56216834e+00 -9.99384622e-01 + 2.11964220e+00 -9.98749118e-02 + 4.01203480e+00 -2.03032745e+00 + -1.24171557e+00 1.97596725e-01 + -1.57230455e+00 4.14126609e-01 + -1.85484741e+00 5.40041563e-01 + 1.76329831e+00 -6.95967734e-01 + -2.29439232e-01 5.08669245e-01 + -5.45124276e+00 2.26907549e+00 + -5.71364288e-02 5.04476476e-01 + 3.12468018e+00 -1.46358879e+00 + 8.20017359e-01 6.51949028e-01 + -1.33977500e+00 2.83634232e-04 + -1.83311685e+00 1.23947117e+00 + 6.31205922e-01 1.19792164e-02 + -2.21967834e+00 6.94056232e-01 + -1.41693842e+00 9.93526233e-01 + -7.58885703e-01 6.78547347e-01 + 3.60239086e+00 -1.08644935e+00 + 6.72217073e-02 
3.00036011e-02 + -3.42680958e-01 -3.48049352e-01 + 1.87546079e+00 -4.78018246e-01 + 7.00485821e-01 -3.52905383e-01 + -8.54580948e-01 8.17330861e-01 + 8.19123706e-01 -5.73927281e-01 + 2.70855639e-01 -3.08940052e-01 + -1.05059952e+00 3.27873168e-01 + 1.08282999e+00 4.84559349e-02 + -7.89899220e-01 1.22291138e+00 + -2.87939816e+00 7.17403497e-01 + -2.08429452e+00 8.87409226e-01 + 1.58409232e+00 -4.74123532e-01 + 1.26882735e+00 1.59162510e-01 + -2.53782993e+00 6.18253491e-01 + -8.92757445e-01 3.35979011e-01 + 1.31867900e+00 -1.17355054e+00 + 1.14918879e-01 -5.35184038e-01 + -1.70288738e-01 5.35868087e-02 + 4.21355121e-01 5.41848690e-02 + 2.07926943e+00 -5.72538144e-01 + 4.08788970e-01 3.77655777e-01 + -3.39631381e+00 9.84216764e-01 + 2.94170163e+00 -1.83120916e+00 + -7.94798752e-01 7.39889052e-01 + 1.46555463e+00 -4.62275563e-01 + 2.57255955e+00 -1.04671434e+00 + 8.45042540e-01 -1.96952892e-01 + -3.23526646e+00 1.60049846e+00 + 3.21948565e+00 -8.88376674e-01 + 1.43005104e+00 -9.21561086e-01 + 8.82360506e-01 2.98403872e-01 + -8.91168097e-01 1.01319072e+00 + -5.13215241e-01 -2.47182649e-01 + -1.35759444e+00 7.07450608e-02 + -4.04550983e+00 2.23534867e+00 + 1.39348883e+00 3.81637747e-01 + -2.85676418e+00 1.53240862e+00 + -1.37183120e+00 6.37977425e-02 + -3.88195859e+00 1.73887145e+00 + 1.19509776e+00 -6.25013512e-01 + -2.80062734e+00 1.79840585e+00 + 1.96558429e+00 -4.70997234e-01 + 1.93111352e+00 -9.70318441e-01 + 3.57991190e+00 -1.65065116e+00 + 2.12831714e+00 -1.11531708e+00 + -3.95661018e-01 -8.54339904e-02 + -2.41630441e+00 1.65166304e+00 + 7.55412624e-01 -1.53453579e-01 + -1.77043450e+00 1.39928715e+00 + -9.32631260e-01 8.73649199e-01 + 1.53342205e+00 -8.39569765e-01 + -6.29846924e-02 1.25023084e-01 + 3.31509049e+00 -1.10733235e+00 + -2.18957109e+00 3.07376993e-01 + -2.35740747e+00 6.47437564e-01 + -2.22142438e+00 8.47318938e-01 + -6.51401147e-01 3.48398562e-01 + 2.75763095e+00 -1.21390708e+00 + 1.12550484e+00 -5.61412847e-01 + -5.65053161e-01 6.74365205e-02 + 
1.68952456e+00 -6.57566096e-01 + 8.95598401e-01 3.96738993e-01 + -1.86537066e+00 9.44129208e-01 + -2.59933294e+00 2.57423247e-01 + -6.59598267e-01 1.91828851e-02 + -2.64506676e+00 8.41783205e-01 + -1.25911802e+00 5.52425066e-01 + -1.39754507e+00 3.73689222e-01 + 5.49550729e-02 1.35071215e+00 + 3.31874811e+00 -1.05682424e+00 + 3.63159604e+00 -1.42864695e+00 + -4.45944617e+00 1.42889446e+00 + 5.87314342e-01 -4.88892988e-01 + -7.26130820e-01 1.51936106e-01 + -1.79246441e+00 6.05888105e-01 + -5.50948207e-01 6.21443081e-01 + -3.17246063e-01 1.77213880e-01 + -2.00098937e+00 1.23799074e+00 + 4.33790961e+00 -1.08490465e+00 + -2.03114114e+00 1.31613237e+00 + -6.29216542e+00 1.92406317e+00 + -1.60265624e+00 8.87947500e-01 + 8.64465062e-01 -8.37416270e-01 + -2.14273937e+00 8.05485900e-01 + -2.36844256e+00 6.17915124e-01 + -1.40429636e+00 6.78296866e-01 + 9.99019988e-01 -5.84297572e-01 + 7.38824546e-01 1.68838678e-01 + 1.45681238e+00 3.04641461e-01 + 2.15914949e+00 -3.43089227e-01 + -1.23895930e+00 1.05339864e-01 + -1.23162264e+00 6.46629863e-01 + 2.28183862e+00 -9.24157063e-01 + -4.29615882e-01 5.69130863e-01 + -1.37449121e+00 -9.12032183e-01 + -7.33890904e-01 -3.91865471e-02 + 8.41400661e-01 -4.76002200e-01 + -1.73349274e-01 -6.84143467e-02 + 3.16042891e+00 -1.32651856e+00 + -3.78244609e+00 2.38619718e+00 + -3.69634380e+00 2.22368561e+00 + 1.83766344e+00 -1.65675953e+00 + -1.63206002e+00 1.19484469e+00 + 3.68480064e-01 -5.70764494e-01 + 3.61982479e-01 1.04274409e-01 + 2.48863048e+00 -1.13285542e+00 + -2.81896488e+00 9.47958768e-01 + 5.74952901e-01 -2.75959392e-01 + 3.72783275e-01 -3.48937848e-01 + 1.95935716e+00 -1.06750415e+00 + 5.19357531e+00 -2.32070803e+00 + 4.09246149e+00 -1.89976700e+00 + -3.36666087e-01 8.17645057e-02 + 1.85453493e-01 3.76913151e-01 + -3.06458262e+00 1.34106402e+00 + -3.13796566e+00 7.00485099e-01 + 1.42964058e+00 -1.35536932e-01 + -1.23440423e-01 4.60094177e-02 + -2.86753037e+00 -5.21724160e-02 + 2.67113726e+00 -1.83746924e+00 + -1.35335062e+00 
1.28238073e+00 + -2.43569899e+00 1.25998539e+00 + 1.26036740e-01 -2.35416844e-01 + -1.35725745e+00 7.37788491e-01 + -3.80897538e-01 3.30757889e-01 + 6.58694434e-01 -1.07566603e+00 + 2.11273640e+00 -9.02260632e-01 + 4.00755057e-01 -2.49229150e-02 + -1.80095812e+00 9.73099742e-01 + -2.68408372e+00 1.63737364e+00 + -2.66079826e+00 7.47289412e-01 + -9.92321439e-02 -1.49331396e-01 + 4.45678251e+00 -1.80352394e+00 + 1.35962915e+00 -1.31554389e+00 + -7.76601417e-01 -9.66173523e-02 + 1.68096348e+00 -6.27235133e-01 + 1.53081227e-01 -3.54216830e-01 + -1.54913095e+00 3.43689269e-01 + 5.29187357e-02 -6.73916964e-01 + -2.06606084e+00 8.34784242e-01 + 1.73701179e+00 -6.06467340e-01 + 1.55856757e+00 -2.58642780e-01 + 1.04349101e+00 -4.43027348e-01 + -1.02397719e+00 1.01308824e+00 + -2.13860204e-01 -4.73347361e-01 + -2.59004955e+00 1.43367853e+00 + 7.98457679e-01 2.18621627e-02 + -1.32974762e+00 4.61802208e-01 + 3.21419359e-01 2.30723316e-02 + 2.87201888e-02 6.24566672e-02 + -1.22261418e+00 6.02340363e-01 + 1.28750335e+00 -3.34839548e-02 + -9.67952623e-01 4.34470505e-01 + 2.02850324e+00 -9.05160255e-01 + -4.13946010e+00 2.33779091e+00 + -4.47508806e-01 3.06440495e-01 + -3.91543394e+00 1.68251022e+00 + -6.45193001e-01 5.29781162e-01 + -2.15518916e-02 5.07278355e-01 + -2.83356868e+00 1.00670227e+00 + 1.82989749e+00 -1.37329222e+00 + -1.09330213e+00 1.08560688e+00 + 1.90533722e+00 -1.28905879e+00 + 2.33986084e+00 2.30642626e-02 + 8.01940220e-01 -1.63986962e+00 + -4.23415165e+00 2.07530423e+00 + 9.33382522e-01 -7.62917211e-01 + -1.84033954e+00 1.07469401e+00 + -2.81938669e+00 1.07342024e+00 + -7.05169988e-01 2.13124943e-01 + 5.09598137e-01 1.32725493e-01 + -2.34558226e+00 8.62383168e-01 + -1.70322072e+00 2.70893796e-01 + 1.23652660e+00 -7.53216034e-02 + 2.84660646e+00 -3.48178304e-02 + 2.50250128e+00 -1.27770855e+00 + -1.00279469e+00 8.77194218e-01 + -4.34674121e-02 -2.12091350e-01 + -5.84151289e-01 1.50382340e-01 + -1.79024013e+00 4.24972808e-01 + -1.23434666e+00 -8.85546570e-02 + 
1.36575412e+00 -6.42639880e-01 + -1.98429947e+00 2.27650336e-01 + 2.36253589e+00 -1.51340773e+00 + 8.79157643e-01 6.84142159e-01 + -2.18577755e+00 2.76526200e-01 + -3.55473434e-01 8.29976561e-01 + 1.16442595e+00 -5.97699411e-01 + -7.35528097e-01 2.40318183e-01 + -1.73702631e-01 7.33788663e-02 + -1.40451745e+00 3.24899628e-01 + -2.05434385e+00 5.68123738e-01 + 8.47876642e-01 -5.74224294e-01 + -6.91955602e-01 1.26009087e+00 + 2.56574498e+00 -1.15602581e+00 + 3.93306545e+00 -1.38398209e+00 + -2.73230251e+00 4.89062581e-01 + -1.04315474e+00 6.06335547e-01 + 1.23231431e+00 -4.46675065e-01 + -3.93035285e+00 1.43287651e+00 + -1.02132111e+00 9.58919791e-01 + -1.49425352e+00 1.06456165e+00 + -6.26485337e-01 1.03791402e+00 + -6.61772998e-01 2.63275425e-01 + -1.80940386e+00 5.70767403e-01 + 9.83720450e-01 -1.39449756e-01 + -2.24619662e+00 9.01044870e-01 + 8.94343014e-01 5.31038678e-02 + 1.95518199e-01 -2.81343295e-01 + -2.30533019e-01 -1.74478106e-01 + -2.01550361e+00 5.55958010e-01 + -4.36281469e+00 1.94374226e+00 + -5.18530457e+00 2.89278357e+00 + 2.67289101e+00 -2.98511449e-01 + -1.53566179e+00 -1.00588944e-01 + -6.09943217e-02 -1.56986047e-01 + -5.22146452e+00 1.66209208e+00 + -3.69777478e+00 2.26154873e+00 + 2.24607181e-01 -4.86934960e-01 + 2.49909450e+00 -1.03033370e+00 + -1.07841120e+00 8.22388054e-01 + -3.20697089e+00 1.09536143e+00 + 3.43524232e+00 -1.47289362e+00 + -5.65784134e-01 4.60365175e-01 + -1.76714734e+00 1.57752346e-01 + -7.77620365e-01 5.60153443e-01 + 6.34399352e-01 -5.22339836e-01 + 2.91011875e+00 -9.72623380e-01 + -1.19286824e+00 6.32370253e-01 + -2.18327609e-01 8.23953181e-01 + 3.42430842e-01 1.37098055e-01 + 1.28658034e+00 -9.11357320e-01 + 2.06914465e+00 -6.67556382e-01 + -6.69451020e-01 -6.38605102e-01 + -2.09312398e+00 1.16743634e+00 + -3.63778357e+00 1.91919157e+00 + 8.74685911e-01 -1.09931208e+00 + -3.91496791e+00 1.00808357e+00 + 1.29621330e+00 -8.32239802e-01 + 9.00222045e-01 -1.31159793e+00 + -1.12242062e+00 1.98517079e-01 + -3.71932852e-01 
1.31667093e-01 + -2.23829610e+00 1.26328346e+00 + -2.08365062e+00 9.93385336e-01 + -1.91082720e+00 7.45866855e-01 + 4.38024917e+00 -2.05901118e+00 + -2.28872886e+00 6.85279335e-01 + 1.01274497e-01 -3.26227153e-01 + -5.04447572e-01 -3.18619513e-01 + 1.28537006e+00 -1.04573551e+00 + -7.83175212e-01 1.54791645e-01 + -3.89239175e+00 1.60017929e+00 + -8.87877111e-01 -1.04968005e-01 + 9.32215179e-01 -5.58691113e-01 + -6.44977127e-01 -2.23018375e-01 + 1.10141900e+00 -1.00666432e+00 + 2.92755687e-01 -1.45480350e-01 + 7.73580681e-01 -2.21150567e-01 + -1.40873709e+00 7.61548044e-01 + -8.89031805e-01 -3.48542923e-01 + 4.16844267e-01 -2.39914494e-01 + -4.64265832e-01 7.29581138e-01 + 1.99835179e+00 -7.70542813e-01 + 4.20523191e-02 -2.18783563e-01 + -6.32611758e-01 -3.09926115e-01 + 6.82912198e-02 -8.48327050e-01 + 1.92425229e+00 -1.37876951e+00 + 3.49461782e+00 -1.88354255e+00 + -3.25209026e+00 1.49809395e+00 + 6.59273182e-01 -2.37435654e-01 + -1.15517300e+00 8.46134387e-01 + 1.26756151e+00 -4.58988026e-01 + -3.99178418e+00 2.04153008e+00 + 7.05687841e-01 -6.83433306e-01 + -1.61997342e+00 8.16577004e-01 + -3.89750399e-01 4.29753250e-01 + -2.53026432e-01 4.92861432e-01 + -3.16788324e+00 4.44285524e-01 + -7.86248901e-01 1.12753716e+00 + -3.02351433e+00 1.28419015e+00 + -1.30131355e+00 1.71226678e+00 + -4.08843475e+00 1.62063214e+00 + -3.09209403e+00 1.19958520e+00 + 1.49102271e+00 -1.11834864e+00 + -3.18059348e+00 5.74587042e-01 + 2.06054867e+00 3.25797860e-03 + -3.50999200e+00 2.02412428e+00 + -8.26610023e-01 3.46528211e-01 + 2.00546034e+00 -4.07333110e-01 + -9.69941653e-01 4.80953753e-01 + 4.47925660e+00 -2.33127314e+00 + 2.03845790e+00 -9.90439915e-01 + -1.11349191e+00 4.31183918e-01 + -4.03628396e+00 1.68509679e+00 + -1.48177601e+00 7.74322088e-01 + 3.07369385e+00 -9.57465886e-01 + 2.39011286e+00 -6.44506921e-01 + 2.91561991e+00 -8.78627328e-01 + 1.10212733e+00 -4.21637388e-01 + 5.31985231e-01 -6.17445696e-01 + -6.82340929e-01 -2.93529716e-01 + 1.94290679e+00 -4.64268634e-01 
+ 1.92262116e+00 -7.93142835e-01 + 4.73762800e+00 -1.63654174e+00 + -3.17848641e+00 8.05791391e-01 + 4.08739432e+00 -1.80816807e+00 + -7.60648826e-01 1.24216138e-01 + -2.24716400e+00 7.90020937e-01 + 1.64284052e+00 -7.18784070e-01 + 1.04410012e-01 -7.11195880e-02 + 2.18268225e+00 -7.01767831e-01 + 2.06218013e+00 -8.70251746e-01 + -1.35266581e+00 7.08456358e-01 + -1.38157779e+00 5.14401086e-01 + -3.28326008e+00 1.20988399e+00 + 8.85358917e-01 -8.12213495e-01 + -2.34067500e+00 3.67657353e-01 + 3.96878127e+00 -1.66841450e+00 + 1.36518053e+00 -8.33436812e-01 + 5.25771988e-01 -5.06121987e-01 + -2.25948361e+00 1.30663765e+00 + -2.57662070e+00 6.32114628e-01 + -3.43134685e+00 2.38106008e+00 + 2.31571924e+00 -1.56566818e+00 + -2.95397202e+00 1.05661888e+00 + -1.35331242e+00 6.76383411e-01 + 1.40977132e+00 -1.17775938e+00 + 1.52561996e+00 -9.83147176e-01 + 2.26550832e+00 -2.10464123e-02 + 6.23371684e-01 -5.30768122e-01 + -4.42356624e-01 9.72226986e-01 + 2.31517901e+00 -1.08468105e+00 + 1.97236640e+00 -1.42016619e+00 + 3.18618687e+00 -1.45056343e+00 + -2.75880360e+00 5.40254980e-01 + -1.92916581e+00 1.45029864e-01 + 1.90022524e+00 -6.03805754e-01 + -1.05446211e+00 5.74361752e-01 + 1.45990390e+00 -9.28233993e-01 + 5.14960557e+00 -2.07564096e+00 + -7.53104842e-01 1.55876958e-01 + 8.09490983e-02 -8.58886384e-02 + -1.56894969e+00 4.53497227e-01 + 1.36944658e-01 5.60670875e-01 + -5.32635329e-01 4.40309945e-01 + 1.32507853e+00 -5.83670099e-01 + 1.20676031e+00 -8.02296831e-01 + -3.65023422e+00 1.17211368e+00 + 1.53393850e+00 -6.17771312e-01 + -3.99977129e+00 1.71415137e+00 + 5.70705058e-01 -4.60771539e-01 + -2.20608002e+00 1.07866596e+00 + -1.09040244e+00 6.77441076e-01 + -5.09886482e-01 -1.97282128e-01 + -1.58062785e+00 6.18333697e-01 + -1.53295020e+00 4.02168701e-01 + -5.18580598e-01 2.25767177e-01 + 1.59514316e+00 -2.54983617e-01 + -5.91938655e+00 2.68223782e+00 + 2.84200509e+00 -1.04685313e+00 + 1.31298664e+00 -1.16672614e+00 + -2.36660033e+00 1.81359460e+00 + 6.94163290e-02 
3.76658816e-01 + 2.33973934e+00 -8.33173023e-01 + -8.24640389e-01 7.83717285e-01 + -1.02888281e+00 1.04680766e+00 + 1.34750745e+00 -5.89568160e-01 + -2.48761231e+00 7.44199284e-01 + -1.04501559e+00 4.72326911e-01 + -3.14610089e+00 1.89843692e+00 + 2.13003416e-01 5.76633620e-01 + -1.69239608e+00 5.66070021e-01 + 1.80491280e+00 -9.31701080e-01 + -6.94362572e-02 6.96026587e-01 + 1.36502578e+00 -6.85599000e-02 + -7.76764337e-01 3.64328661e-01 + -2.67322167e+00 6.80150021e-01 + 1.84338485e+00 -1.18487494e+00 + 2.88009231e+00 -1.25700411e+00 + 1.17114433e+00 -7.69727080e-01 + 2.11576167e+00 2.81502116e-01 + -1.51470088e+00 2.61553540e-01 + 1.18923669e-01 -1.17890202e-01 + 4.48359786e+00 -1.81427466e+00 + -1.27055948e+00 9.92388998e-01 + -8.00276606e-01 9.11326621e-02 + 7.51764024e-01 -1.03676498e-01 + 1.35769348e-01 -2.11470084e-01 + 2.50731332e+00 -1.12418270e+00 + -2.49752781e-01 7.81224033e-02 + -6.23037902e-01 3.16599691e-01 + -3.93772902e+00 1.37195391e+00 + 1.74256361e+00 -1.12363582e+00 + -1.49737281e+00 5.98828310e-01 + 7.75592115e-01 -4.64733802e-01 + -2.26027693e+00 1.36991118e+00 + -1.62849836e+00 7.36899107e-01 + 2.36850751e+00 -9.32126872e-01 + 5.86169745e+00 -2.49342512e+00 + -5.37092226e-01 1.23821274e+00 + 2.80535867e+00 -1.93363302e+00 + -1.77638106e+00 9.10050276e-01 + 3.02692018e+00 -1.60774676e+00 + 1.97833084e+00 -1.50636531e+00 + 9.09168906e-01 -8.83799359e-01 + 2.39769655e+00 -7.56977869e-01 + 1.47283981e+00 -1.06749890e+00 + 2.92060943e-01 -6.07040605e-01 + -2.09278201e+00 7.71858590e-01 + 7.10015905e-01 -5.42768432e-01 + -2.16826169e-01 1.56897896e-01 + 4.56288247e+00 -2.08912680e+00 + -6.63374020e-01 6.67325183e-01 + 1.80564442e+00 -9.76366134e-01 + 3.28720168e+00 -4.66575145e-01 + -1.60463695e-01 -2.58428153e-01 + 1.78590750e+00 -3.96427146e-01 + 2.75950306e+00 -1.82102856e+00 + -1.18234310e+00 6.28073320e-01 + 4.11415835e+00 -2.33551216e+00 + 1.38721004e+00 -2.77450622e-01 + -2.94903545e+00 1.74813352e+00 + 8.67290400e-01 -6.51667894e-01 + 
2.70022274e+00 -8.11832480e-01 + -2.06766146e+00 8.24047249e-01 + 3.90717142e+00 -1.20155758e+00 + -2.95102809e+00 1.36667968e+00 + 6.08815147e+00 -2.60737974e+00 + 2.78576476e+00 -7.86628755e-01 + -3.26258407e+00 1.09302450e+00 + 1.59849422e+00 -1.09705202e+00 + -2.50600710e-01 1.63243175e-01 + -4.90477087e-01 -4.57729572e-01 + -1.24837181e+00 3.22157840e-01 + -2.46341049e+00 1.06517849e+00 + 9.62880751e-01 4.56962496e-01 + 3.99964487e-01 2.07472802e-01 + 6.36657705e-01 -3.46400942e-02 + 4.91231407e-02 -1.40289235e-02 + -4.66683524e-02 -3.72326100e-01 + -5.22049702e-01 -1.70440260e-01 + 5.27062938e-01 -2.32628395e-01 + -2.69440318e+00 1.18914874e+00 + 3.65087539e+00 -1.53427267e+00 + -1.16546364e-01 4.93245392e-02 + 7.55931384e-01 -3.02980139e-01 + 2.06338745e+00 -6.24841225e-01 + 1.31177908e-01 7.29338183e-01 + 1.48021784e+00 -6.39509896e-01 + -5.98656707e-01 2.84525503e-01 + -2.18611080e+00 1.79549812e+00 + -2.91673624e+00 2.15772237e-01 + -8.95591350e-01 7.68250538e-01 + 1.36139762e+00 -1.93845144e-01 + 5.45730414e+00 -2.28114404e+00 + 3.22747247e-01 9.33582332e-01 + -1.46384504e+00 1.12801186e-01 + 4.26728166e-01 -2.33481242e-01 + -1.41327270e+00 8.16103740e-01 + -2.53998067e-01 1.44906646e-01 + -1.32436467e+00 1.87556361e-01 + -3.77313086e+00 1.32896038e+00 + 3.77651731e+00 -1.76548043e+00 + -2.45297093e+00 1.32571926e+00 + -6.55900588e-01 3.56921462e-01 + 9.25558722e-01 -4.51988954e-01 + 1.20732231e+00 -3.02821614e-01 + 3.72660154e-01 -1.89365208e-01 + -1.77090939e+00 9.18087975e-01 + 3.01127567e-01 2.67965829e-01 + -1.76708900e+00 4.62069259e-01 + -2.71812099e+00 1.57233508e+00 + -5.35297633e-01 4.99231535e-01 + 1.50507631e+00 -9.85763646e-01 + 3.00424787e+00 -1.29837562e+00 + -4.99311105e-01 3.91086482e-01 + 1.30125207e+00 -1.26247924e-01 + 4.01699483e-01 -4.46909391e-01 + -1.33635257e+00 5.12068703e-01 + 1.39229757e+00 -9.10974858e-01 + -1.74229508e+00 1.49475978e+00 + -1.21489414e+00 4.04193753e-01 + -3.36537605e-01 -6.74335427e-01 + -2.79186828e-01 
8.48314720e-01 + -2.03080140e+00 1.66599815e+00 + -3.53064281e-01 -7.68582906e-04 + -5.30305657e+00 2.91091546e+00 + -1.20049972e+00 8.26578358e-01 + 2.95906989e-01 2.40215920e-01 + -1.42955534e+00 4.63480310e-01 + -1.87856619e+00 8.21459385e-01 + -2.71124720e+00 1.80246843e+00 + -3.06933780e+00 1.22235760e+00 + 5.21935582e-01 -1.27298218e+00 + -1.34175797e+00 7.69018937e-01 + -1.81962785e+00 1.15528991e+00 + -3.99227550e-01 2.93821598e-01 + 1.22533179e+00 -4.73846323e-01 + -2.08068359e-01 -1.75039817e-01 + -2.03068526e+00 1.50370503e+00 + -3.27606113e+00 1.74906330e+00 + -4.37802587e-01 -2.26956048e-01 + -7.69774213e-02 -3.54922468e-01 + 6.47160749e-02 -2.07334721e-01 + -1.37791524e+00 4.43766709e-01 + 3.29846803e+00 -1.04060799e+00 + -3.63704046e+00 1.05800226e+00 + -1.26716116e+00 1.13077353e+00 + 1.98549075e+00 -1.31864807e+00 + 1.85159500e+00 -5.78629560e-01 + -1.55295206e+00 1.23655857e+00 + 6.76026255e-01 9.18824125e-02 + 1.23418960e+00 -4.68162027e-01 + 2.43186642e+00 -9.22422440e-01 + -3.18729701e+00 1.77582673e+00 + -4.02945613e+00 1.14303496e+00 + -1.92694576e-01 1.03301431e-01 + 1.89554730e+00 -4.60128096e-01 + -2.55626581e+00 1.16057084e+00 + 6.89144365e-01 -9.94982900e-01 + -4.44680606e+00 2.19751983e+00 + -3.15196193e+00 1.18762993e+00 + -1.17434977e+00 1.04534656e+00 + 8.58386984e-02 -1.03947487e+00 + 3.33354973e-01 5.54813610e-01 + -9.37631808e-01 3.33450150e-01 + -2.50232471e+00 5.39720635e-01 + 1.03611949e+00 -7.16304095e-02 + -2.05556816e-02 -3.28992265e-01 + -2.24176201e+00 1.13077506e+00 + 4.53583688e+00 -1.10710212e+00 + 4.77389762e-01 -8.99445512e-01 + -2.69075551e+00 6.83176866e-01 + -2.21779724e+00 1.16916849e+00 + -1.09669056e+00 2.10044765e-01 + -8.45367920e-01 -8.45951423e-02 + 4.37558941e-01 -6.95904256e-01 + 1.84884195e+00 -1.71205136e-01 + -8.36371957e-01 5.62862478e-01 + 1.27786531e+00 -1.33362147e+00 + 2.90684492e+00 -7.49892184e-01 + -3.38652716e+00 1.51180670e+00 + -1.30945978e+00 7.09261928e-01 + -7.50471924e-01 -5.24637889e-01 
+ 1.18580718e+00 -9.97943971e-04 + -7.55395645e+00 3.19273590e+00 + 1.72822535e+00 -1.20996962e+00 + 5.67374320e-01 6.19573416e-01 + -2.99163781e+00 1.79721534e+00 + 1.49862187e+00 -6.05631846e-02 + 1.79503506e+00 -4.90419706e-01 + 3.85626054e+00 -1.95396324e+00 + -9.39188410e-01 7.96498057e-01 + 2.91986664e+00 -1.29392724e+00 + -1.54265750e+00 6.40727933e-01 + 1.14919794e+00 1.20834257e-01 + 2.00936817e+00 -1.53728359e+00 + 3.72468420e+00 -1.38704612e+00 + -1.27794802e+00 3.48543179e-01 + 3.63294077e-01 5.70623314e-01 + 1.49381016e+00 -6.04500534e-01 + 2.98912256e+00 -1.72295726e+00 + -1.80833817e+00 2.94907625e-01 + -3.19669622e+00 1.31888700e+00 + 1.45889401e+00 -8.88448639e-01 + -2.80045388e+00 1.01207060e+00 + -4.78379567e+00 1.48646520e+00 + 2.25510003e+00 -7.13372461e-01 + -9.74441433e-02 -2.17766373e-01 + 2.64468496e-01 -3.60842698e-01 + -5.98821713e+00 3.20197892e+00 + 2.67030213e-01 -5.36386416e-01 + 2.24546960e+00 -8.13464649e-01 + -4.89171414e-01 3.86255031e-01 + -7.45713706e-01 6.29800380e-01 + -3.30460503e-01 3.85127284e-01 + -4.19588147e+00 1.52793198e+00 + 5.42078582e-01 -2.61642741e-02 + 4.24938513e-01 -5.72936751e-01 + 2.82717288e+00 -6.75355024e-01 + -1.44741788e+00 5.03578028e-01 + -1.65547573e+00 7.76444277e-01 + 2.20361170e+00 -1.40835680e+00 + -3.69540235e+00 2.32953767e+00 + -1.41909357e-01 2.28989778e-01 + 1.92838879e+00 -8.72525737e-01 + 1.40708100e+00 -6.81849638e-02 + 1.24988112e+00 -1.39470590e-01 + -2.39435855e+00 7.26587655e-01 + 7.03985028e-01 4.85403277e-02 + 4.05214529e+00 -9.16928318e-01 + 3.74198837e-01 -5.04192358e-01 + -8.43374127e-01 2.36064018e-01 + -3.32253349e-01 7.47840055e-01 + -6.03725210e+00 1.95173337e+00 + 4.60829865e+00 -1.51191309e+00 + -1.46247098e+00 1.11140916e+00 + -9.60111157e-01 -1.23189114e-01 + -7.49613187e-01 4.53614129e-01 + -5.77838219e-01 2.07366469e-02 + 8.07652950e-01 -5.16272662e-01 + -6.02556049e-01 5.05318649e-01 + -1.28712445e-01 2.57836512e-01 + -5.27662820e+00 2.11790737e+00 + 5.40819308e+00 
-2.15366022e+00 + 9.37742513e-02 -1.60221751e-01 + 4.55902865e+00 -1.24646307e+00 + -9.06582589e-01 1.92928110e-01 + 2.99928996e+00 -8.04301218e-01 + -3.24317381e+00 1.80076061e+00 + 3.20421743e-01 8.76524679e-01 + -5.29606705e-01 -3.16717696e-01 + -1.77264560e+00 7.52686776e-01 + -1.51706824e+00 8.43755103e-01 + 1.52759111e+00 -7.86814243e-01 + 4.74845617e-01 4.21319700e-01 + 6.97829149e-01 -8.15664881e-01 + 3.09564973e+00 -1.06202469e+00 + 2.95320379e+00 -1.98963943e+00 + -4.23033224e+00 1.41013338e+00 + 1.48576206e+00 8.02908511e-02 + 4.52041627e+00 -2.04620399e+00 + 6.58403922e-01 -7.60781799e-01 + 2.10667543e-01 1.15241731e-01 + 1.77702583e+00 -8.10271859e-01 + 2.41277385e+00 -1.46972042e+00 + 1.50685525e+00 -1.99272545e-01 + 7.61665522e-01 -4.11276152e-01 + 1.18352312e+00 -9.59908608e-01 + -3.32031305e-01 8.07500132e-02 + 1.16813118e+00 -1.73095194e-01 + 1.18363346e+00 -5.41565052e-01 + 5.17702179e-01 -7.62442035e-01 + 4.57401006e-01 -1.45951115e-02 + 1.49377115e-01 2.99571605e-01 + 1.40399453e+00 -1.30160353e+00 + 5.26231567e-01 3.52783752e-01 + -1.91136514e+00 4.24228635e-01 + 1.74156701e+00 -9.92076776e-01 + -4.89323391e+00 2.32483507e+00 + 2.54011209e+00 -8.80366295e-01 + -5.56925706e-01 1.48842026e-01 + -2.35904668e+00 9.60474853e-01 + 1.42216971e+00 -4.67062761e-01 + -1.10809680e+00 7.68684300e-01 + 4.09674726e+00 -1.90795680e+00 + -2.23048923e+00 9.03812542e-01 + 6.57025763e-01 1.36514871e-01 + 2.10944145e+00 -9.78897838e-02 + 1.22552525e+00 -2.50303867e-01 + 2.84620103e-01 -5.30164020e-01 + -2.13562585e+00 1.03503056e+00 + 1.32414902e-01 -8.14190240e-03 + -5.82433561e-01 3.21020292e-01 + -5.06473247e-01 3.11530419e-01 + 1.57162465e+00 -1.20763919e+00 + -1.43155284e+00 -2.51203698e-02 + -1.47093713e+00 -1.39620999e-01 + -2.65765643e+00 1.06091403e+00 + 2.45992927e+00 -5.88815836e-01 + -1.28440162e+00 -1.99377398e-01 + 6.11257504e-01 -3.73577401e-01 + -3.46606103e-01 6.06081290e-01 + 3.76687505e+00 -8.80181424e-01 + -1.03725103e+00 1.45177517e+00 + 
2.76659936e+00 -1.09361320e+00 + -3.61311296e+00 9.75032455e-01 + 3.22878655e+00 -9.69497365e-01 + 1.43560379e+00 -5.52524585e-01 + 2.94042153e+00 -1.79747037e+00 + 1.30739580e+00 2.47989248e-01 + -4.05056982e-01 1.22831715e+00 + -2.25827421e+00 2.30604626e-01 + 3.69262926e-01 4.32714650e-02 + -5.52064063e-01 6.07806340e-01 + 7.03325987e+00 -2.17956730e+00 + -2.37823835e-01 -8.28068639e-01 + -4.84279888e-01 5.67765194e-01 + -3.15863410e+00 1.02241617e+00 + -3.39561593e+00 1.36876374e+00 + -2.78482934e+00 6.81641104e-01 + -4.37604334e+00 2.23826340e+00 + -2.54049692e+00 8.22676745e-01 + 3.73264822e+00 -9.93498732e-01 + -3.49536064e+00 1.84771519e+00 + 9.81801604e-01 -5.21278776e-01 + 1.52996831e+00 -1.27386206e+00 + -9.23490293e-01 5.29099482e-01 + -2.76999461e+00 9.24831872e-01 + -3.30029834e-01 -2.49645555e-01 + -1.71156166e+00 5.44940854e-01 + -2.37009487e+00 5.83826982e-01 + -3.03216865e+00 1.04922722e+00 + -2.19539936e+00 1.37558730e+00 + 1.15350207e+00 -6.15318535e-01 + 4.62011792e+00 -2.46714517e+00 + 1.52627952e-02 -1.00618283e-01 + -1.10399342e+00 4.87413533e-01 + 3.55448194e+00 -9.10394190e-01 + -5.21890321e+00 2.44710745e+00 + 1.54289749e+00 -6.54269311e-01 + 2.67935674e+00 -9.92758863e-01 + 1.05801310e+00 2.60054285e-02 + 1.52509097e+00 -4.08768600e-01 + 3.27576917e+00 -1.28769406e+00 + 1.71008412e-01 -2.68739994e-01 + -9.83351344e-04 7.02495897e-02 + -7.60795056e-03 1.61968285e-01 + -1.80620472e+00 4.24934471e-01 + 2.32023297e-02 -2.57284559e-01 + 3.98219478e-01 -4.65361935e-01 + 6.63476988e-01 -3.29823196e-02 + 4.00154707e+00 -1.01792211e+00 + -1.50286870e+00 9.46875359e-01 + -2.22717585e+00 7.50636195e-01 + -3.47381508e-01 -6.51596975e-01 + 2.08076453e+00 -8.22800165e-01 + 2.05099963e+00 -4.00868250e-01 + 3.52576988e-02 -2.54418565e-01 + 1.57342042e+00 -7.62166492e-02 + -1.47019722e+00 3.40861172e-01 + -1.21156090e+00 3.21891246e-01 + 3.79729047e+00 -1.54350764e+00 + 1.26459678e-02 6.99203693e-01 + 1.53974177e-01 4.68643204e-01 + -1.73923561e-01 
-1.26229768e-01 + 4.54644993e+00 -2.13951783e+00 + 1.46022547e-01 -4.57084165e-01 + 6.50048037e+00 -2.78872609e+00 + -1.51934912e+00 1.03216768e+00 + -3.06483575e+00 1.81101446e+00 + -2.38212125e+00 9.19559042e-01 + -1.81319611e+00 8.10545112e-01 + 1.70951294e+00 -6.10712680e-01 + 1.67974156e+00 -1.51241453e+00 + -5.94795113e+00 2.56893813e+00 + 3.62633110e-01 -7.46965304e-01 + -2.44042594e+00 8.52761797e-01 + 3.32412550e+00 -1.28439899e+00 + 4.74860766e+00 -1.72821964e+00 + 1.29072541e+00 -8.24872902e-01 + -1.69450702e+00 4.09600876e-01 + 1.29705411e+00 1.22300809e-01 + -2.63597613e+00 8.55612913e-01 + 9.28467301e-01 -2.63550114e-02 + 2.44670264e+00 -4.10123002e-01 + 1.06408206e+00 -5.03361942e-01 + 5.12384049e-02 -1.27116595e-02 + -1.06731272e+00 -1.76205029e-01 + -9.45454582e-01 3.74404917e-01 + 2.54343689e+00 -7.13810545e-01 + -2.54460335e+00 1.31590265e+00 + 1.89864233e+00 -3.98436339e-01 + -1.93990133e+00 6.01474630e-01 + -1.35938824e+00 4.00751788e-01 + 2.38567018e+00 -6.13904880e-01 + 2.18748050e-01 2.62631712e-01 + -2.01388788e+00 1.41474031e+00 + 2.74014581e+00 -1.27448105e+00 + -2.13828583e+00 1.13616144e+00 + 5.98730932e+00 -2.53430080e+00 + -1.72872795e+00 1.53702057e+00 + -2.53263962e+00 1.27342410e+00 + 1.34326968e+00 -1.99395088e-01 + 3.83352666e-01 -1.25683065e-01 + -2.35630657e+00 5.54116983e-01 + -1.94900838e+00 5.76270178e-01 + -1.36699108e+00 -3.40904824e-01 + -2.34727346e+00 -1.93054940e-02 + -3.82779777e+00 1.83025664e+00 + -4.31602080e+00 9.21605705e-01 + 5.54098133e-01 2.33991419e-01 + -4.53591188e+00 1.99833353e+00 + -3.92715909e+00 1.83231482e+00 + 3.91344440e-01 -1.11355111e-01 + 3.48576363e+00 -1.41379449e+00 + -1.42858690e+00 3.84532286e-01 + 1.79519859e+00 -9.23486448e-01 + 8.49691242e-01 -1.76551084e-01 + 1.53618138e+00 8.23835015e-02 + 5.91476520e-02 3.88296940e-02 + 1.44837346e+00 -7.24097604e-01 + -6.79008418e-01 4.04078097e-01 + 2.87555510e+00 -9.51825076e-01 + -1.12379101e+00 2.93457714e-01 + 1.45263980e+00 -6.01960544e-01 + 
-2.55741621e-01 9.26233518e-01 + 3.54570714e+00 -1.41521877e+00 + -1.61542388e+00 6.57844512e-01 + -3.22844269e-01 3.02823546e-01 + 1.03523913e+00 -6.92730711e-01 + 1.11084909e+00 -3.50823642e-01 + 3.41268693e+00 -1.90865862e+00 + 7.67062858e-01 -9.48792160e-01 + -5.49798016e+00 1.71139960e+00 + 1.14865798e+00 -6.12669150e-01 + -2.18256680e+00 7.78634462e-01 + 4.78857389e+00 -2.55555085e+00 + -1.85555569e+00 8.04311615e-01 + -4.22278799e+00 2.01162524e+00 + -1.56556149e+00 1.54353907e+00 + -3.11527864e+00 1.65973526e+00 + 2.66342611e+00 -1.20449402e+00 + 1.57635314e+00 -1.48716308e-01 + -6.35606865e-01 2.59701180e-01 + 1.02431976e+00 -6.76929904e-01 + 1.12973772e+00 1.49473892e-02 + -9.12758116e-01 2.21533933e-01 + -2.98014470e+00 1.71651189e+00 + 2.74016965e+00 -9.47893923e-01 + -3.47830591e+00 1.34941430e+00 + 1.74757562e+00 -3.72503752e-01 + 5.55820383e-01 -6.47992466e-01 + -1.19871928e+00 9.82429151e-01 + -2.53040133e+00 2.10671307e+00 + -1.94085605e+00 1.38938137e+00 diff --git a/data/mllib/sample_fpgrowth.txt b/data/mllib/sample_fpgrowth.txt new file mode 100644 index 0000000000000..c451583e51317 --- /dev/null +++ b/data/mllib/sample_fpgrowth.txt @@ -0,0 +1,6 @@ +r z h k p +z y x w v u t s +s x o n r +x z y m t s q e +z +x z y r q t p diff --git a/data/mllib/sample_isotonic_regression_data.txt b/data/mllib/sample_isotonic_regression_data.txt new file mode 100644 index 0000000000000..d257b509d4d37 --- /dev/null +++ b/data/mllib/sample_isotonic_regression_data.txt @@ -0,0 +1,100 @@ +0.24579296,0.01 +0.28505864,0.02 +0.31208567,0.03 +0.35900051,0.04 +0.35747068,0.05 +0.16675166,0.06 +0.17491076,0.07 +0.04181540,0.08 +0.04793473,0.09 +0.03926568,0.10 +0.12952575,0.11 +0.00000000,0.12 +0.01376849,0.13 +0.13105558,0.14 +0.08873024,0.15 +0.12595614,0.16 +0.15247323,0.17 +0.25956145,0.18 +0.20040796,0.19 +0.19581846,0.20 +0.15757267,0.21 +0.13717491,0.22 +0.19020908,0.23 +0.19581846,0.24 +0.20091790,0.25 +0.16879143,0.26 +0.18510964,0.27 +0.20040796,0.28 
+0.29576747,0.29 +0.43396226,0.30 +0.53391127,0.31 +0.52116267,0.32 +0.48546660,0.33 +0.49209587,0.34 +0.54156043,0.35 +0.59765426,0.36 +0.56144824,0.37 +0.58592555,0.38 +0.52983172,0.39 +0.50178480,0.40 +0.52626211,0.41 +0.58286588,0.42 +0.64660887,0.43 +0.68077511,0.44 +0.74298827,0.45 +0.64864865,0.46 +0.67261601,0.47 +0.65782764,0.48 +0.69811321,0.49 +0.63029067,0.50 +0.61601224,0.51 +0.63233044,0.52 +0.65323814,0.53 +0.65323814,0.54 +0.67363590,0.55 +0.67006629,0.56 +0.51555329,0.57 +0.50892402,0.58 +0.33299337,0.59 +0.36206017,0.60 +0.43090260,0.61 +0.45996940,0.62 +0.56348802,0.63 +0.54920959,0.64 +0.48393677,0.65 +0.48495665,0.66 +0.46965834,0.67 +0.45181030,0.68 +0.45843957,0.69 +0.47118817,0.70 +0.51555329,0.71 +0.58031617,0.72 +0.55481897,0.73 +0.56297807,0.74 +0.56603774,0.75 +0.57929628,0.76 +0.64762876,0.77 +0.66241713,0.78 +0.69301377,0.79 +0.65119837,0.80 +0.68332483,0.81 +0.66598674,0.82 +0.73890872,0.83 +0.73992861,0.84 +0.84242733,0.85 +0.91330954,0.86 +0.88016318,0.87 +0.90719021,0.88 +0.93115757,0.89 +0.93115757,0.90 +0.91942886,0.91 +0.92911780,0.92 +0.95665477,0.93 +0.95002550,0.94 +0.96940337,0.95 +1.00000000,0.96 +0.89801122,0.97 +0.90311066,0.98 +0.90362060,0.99 +0.83477817,1.0 \ No newline at end of file diff --git a/data/mllib/sample_lda_data.txt b/data/mllib/sample_lda_data.txt new file mode 100644 index 0000000000000..2e76702ca9d67 --- /dev/null +++ b/data/mllib/sample_lda_data.txt @@ -0,0 +1,12 @@ +1 2 6 0 2 3 1 1 0 0 3 +1 3 0 1 3 0 0 2 0 0 1 +1 4 1 0 0 4 9 0 1 2 0 +2 1 0 3 0 0 5 0 2 3 9 +3 1 1 9 3 0 2 0 0 1 3 +4 2 0 3 4 5 1 1 1 4 0 +2 1 0 3 0 0 5 0 2 2 9 +1 1 1 9 2 1 2 0 0 1 3 +4 4 0 3 4 2 1 3 0 0 0 +2 8 2 0 3 0 2 0 2 7 2 +1 1 1 9 0 2 2 0 0 3 3 +4 1 0 0 4 5 1 3 0 1 0 diff --git a/dev/check-license b/dev/check-license index 72b1013479964..39943f882b6ca 100755 --- a/dev/check-license +++ b/dev/check-license @@ -27,17 +27,17 @@ acquire_rat_jar () { if [[ ! 
-f "$rat_jar" ]]; then # Download rat launch jar if it hasn't been downloaded yet if [ ! -f "$JAR" ]; then - # Download - printf "Attempting to fetch rat\n" - JAR_DL="${JAR}.part" - if hash curl 2>/dev/null; then - curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" - elif hash wget 2>/dev/null; then - wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" - else - printf "You do not have curl or wget installed, please install rat manually.\n" - exit -1 - fi + # Download + printf "Attempting to fetch rat\n" + JAR_DL="${JAR}.part" + if [ $(command -v curl) ]; then + curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" + elif [ $(command -v wget) ]; then + wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" + else + printf "You do not have curl or wget installed, please install rat manually.\n" + exit -1 + fi fi unzip -tq $JAR &> /dev/null diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b1b8cb44e098b..da15ce3e0e2f7 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -22,8 +22,9 @@ # Expects to be run in a totally empty directory. # # Options: -# --package-only only packages an existing release candidate -# +# --skip-create-release Assume the desired release tag already exists +# --skip-publish Do not publish to Maven central +# --skip-package Do not package and upload binary artifacts # Would be nice to add: # - Send output to stderr and have useful logging in stdout @@ -51,7 +52,7 @@ set -e GIT_TAG=v$RELEASE_VERSION-$RC_NAME -if [[ ! "$@" =~ --package-only ]]; then +if [[ ! "$@" =~ --skip-create-release ]]; then echo "Creating release commit and publishing to Apache repository" # Artifact publishing git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ @@ -87,8 +88,15 @@ if [[ ! 
"$@" =~ --package-only ]]; then git commit -a -m "Preparing development version $next_ver" git push origin $GIT_TAG git push origin HEAD:$GIT_BRANCH - git checkout -f $GIT_TAG + popd + rm -rf spark +fi +if [[ ! "$@" =~ --skip-publish ]]; then + git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git + pushd spark + git checkout --force $GIT_TAG + # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API echo "Creating Nexus staging repository" @@ -106,7 +114,7 @@ if [[ ! "$@" =~ --package-only ]]; then clean install ./dev/change-version-to-2.11.sh - + mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install @@ -122,8 +130,14 @@ if [[ ! "$@" =~ --package-only ]]; then for file in $(find . -type f) do echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - gpg --print-md MD5 $file > $file.md5; - gpg --print-md SHA1 $file > $file.sha1 + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 done nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id @@ -149,88 +163,90 @@ if [[ ! "$@" =~ --package-only ]]; then rm -rf spark fi -# Source and binary tarballs -echo "Packaging release tarballs" -git clone https://git-wip-us.apache.org/repos/asf/spark.git -cd spark -git checkout --force $GIT_TAG -release_hash=`git rev-parse HEAD` - -rm .gitignore -rm -rf .git -cd .. 
- -cp -r spark spark-$RELEASE_VERSION -tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION -echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION.tgz -echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.md5 -echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.sha -rm -rf spark-$RELEASE_VERSION - -make_binary_release() { - NAME=$1 - FLAGS=$2 - cp -r spark spark-$RELEASE_VERSION-bin-$NAME - - cd spark-$RELEASE_VERSION-bin-$NAME - - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-version-to-2.11.sh - fi - - ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log +if [[ ! "$@" =~ --skip-package ]]; then + # Source and binary tarballs + echo "Packaging release tarballs" + git clone https://git-wip-us.apache.org/repos/asf/spark.git + cd spark + git checkout --force $GIT_TAG + release_hash=`git rev-parse HEAD` + + rm .gitignore + rm -rf .git cd .. - cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . 
- rm -rf spark-$RELEASE_VERSION-bin-$NAME - - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ - --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.sha -} - - -make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & -make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" & -make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & -make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" & -make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" & -make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" & -make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" & -make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & -wait - -# Copy data -echo "Copying release tarballs" -rc_folder=spark-$RELEASE_VERSION-$RC_NAME -ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_folder -scp spark-* \ - $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ - -# Docs -cd spark -build/sbt clean -cd docs -# Compile docs with Java 7 to use nicer format -JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build -echo "Copying release documentation" -rc_docs_folder=${rc_folder}-docs -ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder -rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder - -echo "Release $RELEASE_VERSION completed:" -echo "Git tag:\t $GIT_TAG" -echo "Release commit:\t $release_hash" -echo "Binary location:\t 
http://people.apache.org/~$ASF_USERNAME/$rc_folder" -echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" + + cp -r spark spark-$RELEASE_VERSION + tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ + --detach-sig spark-$RELEASE_VERSION.tgz + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ + spark-$RELEASE_VERSION.tgz.md5 + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ + spark-$RELEASE_VERSION.tgz.sha + rm -rf spark-$RELEASE_VERSION + + make_binary_release() { + NAME=$1 + FLAGS=$2 + cp -r spark spark-$RELEASE_VERSION-bin-$NAME + + cd spark-$RELEASE_VERSION-bin-$NAME + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-version-to-2.11.sh + fi + + ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log + cd .. + cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . 
+ rm -rf spark-$RELEASE_VERSION-bin-$NAME + + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ + --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ + --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ + MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ + spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ + SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ + spark-$RELEASE_VERSION-bin-$NAME.tgz.sha + } + + + make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & + make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" & + make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & + make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" & + make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" & + make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" & + make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" & + make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & + wait + + # Copy data + echo "Copying release tarballs" + rc_folder=spark-$RELEASE_VERSION-$RC_NAME + ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_folder + scp spark-* \ + $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ + + # Docs + cd spark + sbt/sbt clean + cd docs + # Compile docs with Java 7 to use nicer format + JAVA_HOME="$JAVA_7_HOME" PRODUCTION=1 RELEASE_VERSION="$RELEASE_VERSION" jekyll build + echo "Copying release documentation" + rc_docs_folder=${rc_folder}-docs + ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder + rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder + + echo "Release $RELEASE_VERSION completed:" + echo "Git tag:\t $GIT_TAG" + echo "Release commit:\t 
$release_hash" + echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" + echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" +fi diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index dfa924d2aa0ba..3062e9c3c6651 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -244,6 +244,8 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): versions = asf_jira.project_versions("SPARK") versions = sorted(versions, key=lambda x: x.name, reverse=True) versions = filter(lambda x: x.raw['released'] is False, versions) + # Consider only x.y.z versions + versions = filter(lambda x: re.match('\d+\.\d+\.\d+', x.name), versions) default_fix_versions = map(lambda x: fix_version_from_branch(x, versions).name, merge_branches) for v in default_fix_versions: diff --git a/dev/run-tests b/dev/run-tests index 2257a566bb1bb..483958757a2dd 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -36,7 +36,7 @@ function handle_error () { } -# Build against the right verison of Hadoop. +# Build against the right version of Hadoop. { if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then @@ -77,7 +77,7 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" fi } -# Only run Hive tests if there are sql changes. +# Only run Hive tests if there are SQL changes. # Partial solution for SPARK-1455. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master @@ -183,7 +183,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string # will be interpreted as a single test, which doesn't work. 
- SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi diff --git a/docs/_config.yml b/docs/_config.yml index e2db274e1f619..0652927a8ce9b 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -10,6 +10,7 @@ kramdown: include: - _static + - _modules # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 8841f7675d35e..efc4c612937df 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -7,7 +7,9 @@ {{ page.title }} - Spark {{site.SPARK_VERSION_SHORT}} Documentation - + {% if page.description %} + + {% endif %} {% if page.redirect %} diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index 7e55131754a3f..c2fe6b0e286ce 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -1,6 +1,7 @@ --- layout: global -title: Bagel Programming Guide +displayTitle: Bagel Programming Guide +title: Bagel --- **Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** diff --git a/docs/building-spark.md b/docs/building-spark.md index fb93017861ed0..4c3988e819ad8 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -111,9 +111,9 @@ To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` prop dev/change-version-to-2.11.sh mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package -Scala 2.11 support in Spark is experimental and does not support a few features. -Specifically, Spark's external Kafka library and JDBC component are not yet -supported in Scala 2.11 builds. +Scala 2.11 support in Spark does not support a few features due to dependencies +which are themselves not Scala 2.11 ready. 
Specifically, Spark's external +Kafka library and JDBC component are not yet supported in Scala 2.11 builds. # Spark Tests in Maven @@ -137,15 +137,18 @@ We use the scala-maven-plugin which supports incremental and continuous compilat should run continuous compilation (i.e. wait for changes). However, this has not been tested extensively. A couple of gotchas to note: + * it only scans the paths `src/main` and `src/test` (see [docs](http://scala-tools.org/mvnsites/maven-scala-plugin/usage_cc.html)), so it will only work from within certain submodules that have that structure. + * you'll typically need to run `mvn install` from the project root for compilation within specific submodules to work; this is because submodules that depend on other submodules do so via the `spark-parent` module). Thus, the full flow for running continuous-compilation of the `core` submodule may look more like: - ``` + +``` $ mvn install $ cd core $ mvn scala:cc @@ -156,14 +159,6 @@ Thus, the full flow for running continuous-compilation of the `core` submodule m For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the [wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-IDESetup). -# Building Spark Debian Packages - -The Maven build includes support for building a Debian package containing the assembly 'fat-jar', PySpark, and the necessary scripts and configuration files. This can be created by specifying the following: - - mvn -Pdeb -DskipTests clean package - -The debian package can then be found under assembly/target. We added the short commit hash to the file name so that we can distinguish individual packages built for SNAPSHOT versions. - # Running Java 8 Test Suites Running only Java 8 tests and nothing else. 
diff --git a/docs/configuration.md b/docs/configuration.md index 673cdb371a512..8dd2bad61344f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark Configuration +displayTitle: Spark Configuration +title: Configuration --- * This will become a table of contents (this text will be scraped). {:toc} @@ -94,30 +95,12 @@ of the most common options to set are: - - - - - - - - - - - - + + - - - + + - - + + - + @@ -183,6 +163,14 @@ of the most common options to set are: Logs the effective SparkConf as INFO when a SparkContext is started. + + + + +
    Executor IDTotal Tasks Failed Tasks Succeeded TasksInputOutputShuffle ReadShuffle WriteShuffle Spill (Memory)Shuffle Spill (Disk) + Input Size / Records + + Output Size / Records + + + Shuffle Read Size / Records + + + Shuffle Write Size / Records + Shuffle Spill (Memory)Shuffle Spill (Disk)
    {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} - {Utils.bytesToString(v.inputBytes)} - {Utils.bytesToString(v.outputBytes)} - {Utils.bytesToString(v.shuffleRead)} - {Utils.bytesToString(v.shuffleWrite)} - {Utils.bytesToString(v.memoryBytesSpilled)} - {Utils.bytesToString(v.diskBytesSpilled)} + {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} + + {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} + + {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} + + {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} + + {Utils.bytesToString(v.memoryBytesSpilled)} + + {Utils.bytesToString(v.diskBytesSpilled)} +
    {UIUtils.formatDuration(millis.toLong)}{Utils.bytesToString(d.toLong)}{Utils.bytesToString(d.toLong)}{s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"}InputInput Size / RecordsOutputOutput Size / Records + + Shuffle Read Blocked Time + + + + Shuffle Read Size / Records + + Shuffle Read (Remote) + + Shuffle Remote Reads + + Shuffle WriteShuffle Write Size / Records
    - {inputReadable} + {s"$inputReadable / $inputRecords"} - {outputReadable} + {s"$outputReadable / $outputRecords"} + {shuffleReadBlockedTimeReadable} + - {shuffleReadReadable} + {s"$shuffleReadReadable / $shuffleReadRecords"} + + {shuffleReadRemoteReadable} - {shuffleWriteReadable} + {s"$shuffleWriteReadable / $shuffleWriteRecords"}
    spark.master(none) - The cluster manager to connect to. See the list of - allowed master URL's. -
    spark.executor.memory512m - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). -
    spark.driver.memory512mspark.driver.cores1 - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 512m, 2g). + Number of cores to use for the driver process, only in cluster mode.
    spark.driver.maxResultSize 1g @@ -130,38 +113,35 @@ of the most common options to set are:
    spark.serializerorg.apache.spark.serializer.
    JavaSerializer
    spark.driver.memory512m - Class to use for serializing objects that will be sent over the network or need to be cached - in serialized form. The default of Java serialization works with any Serializable Java object - but is quite slow, so we recommend using - org.apache.spark.serializer.KryoSerializer and configuring Kryo serialization - when speed is necessary. Can be any subclass of - - org.apache.spark.Serializer. + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 512m, 2g). + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-memory command line option + or in your default properties file.
    spark.kryo.classesToRegister(none)spark.executor.memory512m - If you use Kryo serialization, give a comma-separated list of custom class names to register - with Kryo. - See the tuning guide for more details. + Amount of memory to use per executor process, in the same format as JVM memory strings + (e.g. 512m, 2g).
    spark.kryo.registratorspark.extraListeners (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. This - property is useful if you need to register your classes in a custom way, e.g. to specify a custom - field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be - set to a class that extends - - KryoRegistrator. - See the tuning guide for more details. + A comma-separated list of classes that implement SparkListener; when initializing + SparkContext, instances of these classes will be created and registered with Spark's listener + bus. If a class has a single-argument constructor that accepts a SparkConf, that constructor + will be called; otherwise, a zero-argument constructor will be called. If no valid constructor + can be found, the SparkContext creation will fail with an exception.
    spark.master(none) + The cluster manager to connect to. See the list of + allowed master URL's. +
    Apart from these, the following properties are also available, and may be useful in some situations: @@ -191,51 +179,84 @@ Apart from these, the following properties are also available, and may be useful - + + + + + + + + + + + + + + + + + + + - + - + - - + + @@ -249,30 +270,40 @@ Apart from these, the following properties are also available, and may be useful - + + + + + + - + - - + + @@ -283,6 +314,9 @@ Apart from these, the following properties are also available, and may be useful or it will be displayed before the driver exiting. It also can be dumped into disk by `sc.dump_profiles(path)`. If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. + + By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by + passing a profiler class in as a parameter to the `SparkContext` constructor. @@ -295,6 +329,15 @@ Apart from these, the following properties are also available, and may be useful automatically. + + + + + @@ -305,40 +348,38 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task. +
    Property NameDefaultMeaning
    spark.executor.extraJavaOptionsspark.driver.extraClassPath(none) + Extra classpath entries to append to the classpath of the driver. + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-class-path command line option or in + your default properties file.
    spark.driver.extraJavaOptions(none) + A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-java-options command line option or in + your default properties file.
    spark.driver.extraLibraryPath (none) - A string of extra JVM options to pass to executors. For instance, GC settings or other - logging. Note that it is illegal to set Spark properties or heap size settings with this - option. Spark properties should be set using a SparkConf object or the - spark-defaults.conf file used with the spark-submit script. Heap size settings can be set - with spark.executor.memory. + Set a special library path to use when launching the driver JVM. + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-library-path command line option or in + your default properties file.
    spark.driver.userClassPathFirstfalse + (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading + classes in the driver. This feature can be used to mitigate conflicts between Spark's + dependencies and user dependencies. It is currently an experimental feature. + + This is used in cluster mode only.
    spark.executor.extraClassPath (none) - Extra classpath entries to append to the classpath of executors. This exists primarily - for backwards-compatibility with older versions of Spark. Users typically should not need - to set this option. + Extra classpath entries to append to the classpath of executors. This exists primarily for + backwards-compatibility with older versions of Spark. Users typically should not need to set + this option.
    spark.executor.extraLibraryPathspark.executor.extraJavaOptions (none) - Set a special library path to use when launching executor JVM's. + A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or heap size settings with this option. Spark + properties should be set using a SparkConf object or the spark-defaults.conf file used with the + spark-submit script. Heap size settings can be set with spark.executor.memory.
    spark.executor.logs.rolling.strategyspark.executor.extraLibraryPath (none) - Set the strategy of rolling of executor logs. By default it is disabled. It can - be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", - use spark.executor.logs.rolling.time.interval to set the rolling interval. - For "size", use spark.executor.logs.rolling.size.maxBytes to set - the maximum file size for rolling. + Set a special library path to use when launching executor JVM's.
    spark.executor.logs.rolling.time.intervaldailyspark.executor.logs.rolling.maxRetainedFiles(none) - Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or - any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles - for automatic cleaning of old logs. + Sets the number of latest rolling log files that are going to be retained by the system. + Older log files will be deleted. Disabled by default.
    spark.executor.logs.rolling.maxRetainedFilesspark.executor.logs.rolling.strategy (none) - Sets the number of latest rolling log files that are going to be retained by the system. - Older log files will be deleted. Disabled by default. + Set the strategy of rolling of executor logs. By default it is disabled. It can + be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", + use spark.executor.logs.rolling.time.interval to set the rolling interval. + For "size", use spark.executor.logs.rolling.size.maxBytes to set + the maximum file size for rolling. +
    spark.executor.logs.rolling.time.intervaldaily + Set the time interval by which the executor logs will be rolled over. + Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles + for automatic cleaning of old logs.
    spark.files.userClassPathFirstspark.executor.userClassPathFirst false - (Experimental) Whether to give user-added jars precedence over Spark's own jars when - loading classes in Executors. This feature can be used to mitigate conflicts between - Spark's dependencies and user dependencies. It is currently an experimental feature. - (Currently, this setting does not work for YARN, see SPARK-2996 for more details). + (Experimental) Same functionality as spark.driver.userClassPathFirst, but + applied to executor instances.
    spark.python.worker.memory512mspark.executorEnv.[EnvironmentVariableName](none) - Amount of memory to use per python worker process during aggregation, in the same - format as JVM memory strings (e.g. 512m, 2g). If the memory - used during aggregation goes above this amount, it will spill the data into disks. + Add the environment variable specified by EnvironmentVariableName to the Executor + process. The user can specify multiple of these to set multiple environment variables.
    spark.python.worker.memory512m + Amount of memory to use per python worker process during aggregation, in the same + format as JVM memory strings (e.g. 512m, 2g). If the memory + used during aggregation goes above this amount, it will spill the data into disks. +
    spark.python.worker.reuse true
    + +#### Shuffle Behavior + + - - + + - - + + - - + + -
    Property NameDefaultMeaning
    spark.executorEnv.[EnvironmentVariableName](none)spark.reducer.maxMbInFlight48 - Add the environment variable specified by EnvironmentVariableName to the Executor - process. The user can specify multiple of these to set multiple environment variables. + Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since + each output requires us to create a buffer to receive it, this represents a fixed memory + overhead per reduce task, so keep it small unless you have a large amount of memory.
    spark.mesos.executor.homedriver side SPARK_HOMEspark.shuffle.blockTransferServicenetty - Set the directory in which Spark is installed on the executors in Mesos. By default, the - executors will simply use the driver's Spark home directory, which may not be visible to - them. Note that this is only relevant if a Spark binary package is not specified through - spark.executor.uri. + Implementation to use for transferring shuffle and cached blocks between executors. There + are two implementations available: netty and nio. Netty-based + block transfer is intended to be simpler but equally efficient and is the default option + starting in 1.2.
    spark.mesos.executor.memoryOverheadexecutor memory * 0.07, with minimum of 384spark.shuffle.compresstrue - This value is an additive for spark.executor.memory, specified in MiB, - which is used to calculate the total Mesos task memory. A value of 384 - implies a 384MiB overhead. Additionally, there is a hard-coded 7% minimum - overhead. The final overhead will be the larger of either - `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`. + Whether to compress map output files. Generally a good idea. Compression will use + spark.io.compression.codec.
    - -#### Shuffle Behavior - - @@ -350,55 +391,46 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + - - + + - + - - - - - - - + + @@ -410,6 +442,17 @@ Apart from these, the following properties are also available, and may be useful the default option starting in 1.2. + + + + + @@ -419,13 +462,19 @@ Apart from these, the following properties are also available, and may be useful - - + + + + + + +
    Property NameDefaultMeaning
    spark.shuffle.consolidateFiles false
    spark.shuffle.spilltruespark.shuffle.file.buffer.kb32 - If set to "true", limits the amount of memory used during reduces by spilling data out to disk. - This spilling threshold is specified by spark.shuffle.memoryFraction. + Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers + reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
    spark.shuffle.spill.compresstruespark.shuffle.io.maxRetries3 - Whether to compress data spilled during shuffles. Compression will use - spark.io.compression.codec. + (Netty only) Fetches that fail due to IO-related exceptions are automatically retried if this is + set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC + pauses or transient network connectivity issues.
    spark.shuffle.memoryFraction0.2spark.shuffle.io.numConnectionsPerPeer1 - Fraction of Java heap to use for aggregation and cogroups during shuffles, if - spark.shuffle.spill is true. At any given time, the collective size of - all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will - begin to spill to disk. If spills are often, consider increasing this value at the expense of - spark.storage.memoryFraction. + (Netty only) Connections between hosts are reused in order to reduce connection buildup for + large clusters. For clusters with many hard disks and few hosts, this may result in insufficient + concurrency to saturate all disks, and so users may consider increasing this value.
    spark.shuffle.compressspark.shuffle.io.preferDirectBufs true - Whether to compress map output files. Generally a good idea. Compression will use - spark.io.compression.codec. -
    spark.shuffle.file.buffer.kb32 - Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers - reduce the number of disk seeks and system calls made in creating intermediate shuffle files. + (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache + block transfer. For environments where off-heap memory is tightly limited, users may wish to + turn this off to force all allocations from Netty to be on-heap.
    spark.reducer.maxMbInFlight48spark.shuffle.io.retryWait5 - Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since - each output requires us to create a buffer to receive it, this represents a fixed memory - overhead per reduce task, so keep it small unless you have a large amount of memory. + (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying + is simply maxRetries * retryWait, by default 15 seconds.
    spark.shuffle.memoryFraction0.2 + Fraction of Java heap to use for aggregation and cogroups during shuffles, if + spark.shuffle.spill is true. At any given time, the collective size of + all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will + begin to spill to disk. If spills occur often, consider increasing this value at the expense of + spark.storage.memoryFraction. + 
    spark.shuffle.sort.bypassMergeThreshold 200
    spark.shuffle.blockTransferServicenettyspark.shuffle.spilltrue - Implementation to use for transferring shuffle and cached blocks between executors. There - are two implementations available: netty and nio. Netty-based - block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2. + If set to "true", limits the amount of memory used during reduces by spilling data out to disk. + This spilling threshold is specified by spark.shuffle.memoryFraction. +
    spark.shuffle.spill.compresstrue + Whether to compress data spilled during shuffles. Compression will use + spark.io.compression.codec.
    @@ -434,26 +483,28 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + - + Base directory in which Spark events are logged, if spark.eventLog.enabled is true. + Within this base directory, Spark creates a sub-directory for each application, and logs the + events specific to the application in this directory. Users may want to set this to + a unified location like an HDFS directory so history files can be read by the history server. + + - - + + @@ -464,28 +515,26 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + - - + +
    Property NameDefaultMeaning
    spark.ui.port4040spark.eventLog.compressfalse - Port for your application's dashboard, which shows memory and workload data. + Whether to compress logged events, if spark.eventLog.enabled is true.
    spark.ui.retainedStages1000spark.eventLog.dirfile:///tmp/spark-events - How many stages the Spark UI and status APIs remember before garbage - collecting. -
    spark.ui.retainedJobs1000spark.eventLog.enabledfalse - How many jobs the Spark UI and status APIs remember before garbage - collecting. + Whether to log Spark events, useful for reconstructing the Web UI after the application has + finished.
    spark.eventLog.enabledfalsespark.ui.port4040 - Whether to log Spark events, useful for reconstructing the Web UI after the application has - finished. + Port for your application's dashboard, which shows memory and workload data.
    spark.eventLog.compressfalsespark.ui.retainedJobs1000 - Whether to compress logged events, if spark.eventLog.enabled is true. + How many jobs the Spark UI and status APIs remember before garbage + collecting.
    spark.eventLog.dirfile:///tmp/spark-eventsspark.ui.retainedStages1000 - Base directory in which Spark events are logged, if spark.eventLog.enabled is true. - Within this base directory, Spark creates a sub-directory for each application, and logs the - events specific to the application in this directory. Users may want to set this to - a unified location like an HDFS directory so history files can be read by the history server. + How many stages the Spark UI and status APIs remember before garbage + collecting.
    @@ -501,12 +550,10 @@ Apart from these, the following properties are also available, and may be useful - spark.rdd.compress - false + spark.closure.serializer + org.apache.spark.serializer.
    JavaSerializer - Whether to compress serialized RDD partitions (e.g. for - StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some - extra CPU time. + Serializer class to use for closures. Currently only the Java serializer is supported. @@ -522,14 +569,6 @@ Apart from these, the following properties are also available, and may be useful and org.apache.spark.io.SnappyCompressionCodec. - - spark.io.compression.snappy.block.size - 32768 - - Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec - is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. - - spark.io.compression.lz4.block.size 32768 @@ -539,21 +578,20 @@ Apart from these, the following properties are also available, and may be useful - spark.closure.serializer - org.apache.spark.serializer.
    JavaSerializer + spark.io.compression.snappy.block.size + 32768 - Serializer class to use for closures. Currently only the Java serializer is supported. + Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec + is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. - spark.serializer.objectStreamReset - 100 + spark.kryo.classesToRegister + (none) - When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches - objects to prevent writing redundant data, however that stops garbage collection of those - objects. By calling 'reset' you flush that info from the serializer, and allow old - objects to be collected. To turn off this periodic reset set it to -1. - By default it will reset the serializer every 100 objects. + If you use Kryo serialization, give a comma-separated list of custom class names to register + with Kryo. + See the tuning guide for more details. @@ -578,12 +616,16 @@ Apart from these, the following properties are also available, and may be useful - spark.kryoserializer.buffer.mb - 0.064 + spark.kryo.registrator + (none) - Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer - per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max.mb if needed. + If you use Kryo serialization, set this class to register your custom classes with Kryo. This + property is useful if you need to register your classes in a custom way, e.g. to specify a custom + field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be + set to a class that extends + + KryoRegistrator. + See the tuning guide for more details. @@ -595,11 +637,80 @@ Apart from these, the following properties are also available, and may be useful inside Kryo. + + spark.kryoserializer.buffer.mb + 0.064 + + Initial size of Kryo's serialization buffer, in megabytes. 
Note that there will be one buffer + per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max.mb if needed. + + + + spark.rdd.compress + false + + Whether to compress serialized RDD partitions (e.g. for + StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some + extra CPU time. + + + + spark.serializer + org.apache.spark.serializer.
    JavaSerializer + + Class to use for serializing objects that will be sent over the network or need to be cached + in serialized form. The default of Java serialization works with any Serializable Java object + but is quite slow, so we recommend using + org.apache.spark.serializer.KryoSerializer and configuring Kryo serialization + when speed is necessary. Can be any subclass of + + org.apache.spark.Serializer. + + + + spark.serializer.objectStreamReset + 100 + + When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches + objects to prevent writing redundant data, however that stops garbage collection of those + objects. By calling 'reset' you flush that info from the serializer, and allow old + objects to be collected. To turn off this periodic reset set it to -1. + By default it will reset the serializer every 100 objects. + + #### Execution Behavior + + + + + + + + + + + + + + + - - - + + + - - + + @@ -642,12 +752,23 @@ Apart from these, the following properties are also available, and may be useful - - - + + + + + + + + @@ -658,6 +779,15 @@ Apart from these, the following properties are also available, and may be useful increase it if you configure your own old generation size. + + + + + @@ -676,15 +806,6 @@ Apart from these, the following properties are also available, and may be useful directories on Tachyon file system. - - - - - @@ -692,106 +813,19 @@ Apart from these, the following properties are also available, and may be useful The URL of the underlying Tachyon file system in the TachyonStore. - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.broadcast.blockSize4096 + Size of each piece of a block in kilobytes for TorrentBroadcastFactory. + Too large a value decreases parallelism during broadcast (makes it slower); however, if it is + too small, BlockManager might take a performance hit. +
    spark.broadcast.factoryorg.apache.spark.broadcast.
    TorrentBroadcastFactory
    + Which broadcast implementation to use. +
    spark.cleaner.ttl(infinite) + Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks + generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be + forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in + case of Spark Streaming applications). Note that any RDD that persists in memory for more than + this duration will be cleared as well. +
    spark.default.parallelism @@ -618,19 +729,18 @@ Apart from these, the following properties are also available, and may be useful
    spark.broadcast.factoryorg.apache.spark.broadcast.
    TorrentBroadcastFactory
    - Which broadcast implementation to use. - spark.executor.heartbeatInterval10000Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let + the driver know that the executor is still alive and update it with metrics for in-progress + tasks.
    spark.broadcast.blockSize4096spark.files.fetchTimeout60 - Size of each piece of a block in kilobytes for TorrentBroadcastFactory. - Too large a value decreases parallelism during broadcast (makes it slower); however, if it is - too small, BlockManager might take a performance hit. + Communication timeout to use when fetching files added through SparkContext.addFile() from + the driver, in seconds.
    spark.files.fetchTimeout60 - Communication timeout to use when fetching files added through SparkContext.addFile() from - the driver, in seconds. - spark.hadoop.cloneConffalseIf set to true, clones a new Hadoop Configuration object for each task. This + option should be enabled to work around Configuration thread-safety issues (see + SPARK-2546 for more details). + This is disabled by default in order to avoid unexpected performance regressions for jobs that + are not affected by these issues.
    spark.hadoop.validateOutputSpecstrueIf set to true, validates the output specification (e.g. checking if the output directory already exists) + used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing + output directories. We recommend that users do not disable this except if trying to achieve compatibility with + previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since + data may need to be rewritten to pre-existing output directories during checkpoint recovery.
    spark.storage.memoryFraction
    spark.storage.memoryMapThreshold2097152 + Size of a block, in bytes, above which Spark memory maps when reading a block from disk. + This prevents Spark from memory mapping very small blocks. In general, memory + mapping has high overhead for blocks close to or below the page size of the operating system. +
    spark.storage.unrollFraction 0.2
    spark.storage.memoryMapThreshold2097152 - Size of a block, in bytes, above which Spark memory maps when reading a block from disk. - This prevents Spark from memory mapping very small blocks. In general, memory - mapping has high overhead for blocks close to or below the page size of the operating system. -
    spark.tachyonStore.url tachyon://localhost:19998
    spark.cleaner.ttl(infinite) - Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks - generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be - forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in - case of Spark Streaming applications). Note that any RDD that persists in memory for more than - this duration will be cleared as well. -
    spark.hadoop.validateOutputSpecstrueIf set to true, validates the output specification (e.g. checking if the output directory already exists) - used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing - output directories. We recommend that users do not disable this except if trying to achieve compatibility with - previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. - This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since - data may need to be rewritten to pre-existing output directories during checkpoint recovery.
    spark.hadoop.cloneConffalseIf set to true, clones a new Hadoop Configuration object for each task. This - option should be enabled to work around Configuration thread-safety issues (see - SPARK-2546 for more details). - This is disabled by default in order to avoid unexpected performance regressions for jobs that - are not affected by these issues.
    spark.executor.heartbeatInterval10000Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let - the driver know that the executor is still alive and update it with metrics for in-progress - tasks.
    #### Networking - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + @@ -804,181 +838,139 @@ Apart from these, the following properties are also available, and may be useful - - - - - - - - - - - - + + - - - - - - - - - - - - + + - - + + - - + + - - + + - -
    Property NameDefaultMeaning
    spark.driver.host(local hostname) - Hostname or IP address for the driver to listen on. - This is used for communicating with the executors and the standalone Master. -
    spark.driver.port(random) - Port for the driver to listen on. - This is used for communicating with the executors and the standalone Master. -
    spark.fileserver.port(random) - Port for the driver's HTTP file server to listen on. -
    spark.broadcast.port(random) - Port for the driver's HTTP broadcast server to listen on. - This is not relevant for torrent broadcast. -
    spark.replClassServer.port(random) - Port for the driver's HTTP class server to listen on. - This is only relevant for the Spark shell. -
    spark.blockManager.port(random) - Port for all block managers to listen on. These exist on both the driver and the executors. -
    spark.executor.port(random) - Port for the executor to listen on. This is used for communicating with the driver. -
    spark.port.maxRetries16spark.akka.failure-detector.threshold300.0 - Default maximum number of retries when binding to a port before giving up. + This is set to a larger value to disable the failure detector that comes built in to Akka. It can be + enabled again, if you plan to use this feature (Not recommended). This maps to Akka's + `akka.remote.transport-failure-detector.threshold`. Tune this in combination with + `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to.
    spark.akka.threads4 - Number of actor threads to use for communication. Can be useful to increase on large clusters - when the driver has a lot of CPU cores. -
    spark.akka.timeout100 - Communication timeout between Spark nodes, in seconds. -
    spark.network.timeout120spark.akka.heartbeat.interval1000 - Default timeout for all network interactions, in seconds. This config will be used in - place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, - spark.storage.blockManagerSlaveTimeoutMs or - spark.shuffle.io.connectionTimeout, if they are not configured. + This is set to a larger value to disable the transport failure detector that comes built in to + Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger + interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more + informative for Akka's failure detector. Tune this in combination with `spark.akka.heartbeat.pauses` + if you need to. A likely positive use case for using the failure detector would be: a sensitive + failure detector can help evict rogue executors quickly. However this is usually not the case + as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling + this leads to a lot of exchanges of heart beats between nodes leading to flooding the network + with those.
    spark.akka.heartbeat.pauses 6000 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause - in seconds for akka. This can be used to control sensitivity to gc pauses. Tune this in - combination of `spark.akka.heartbeat.interval` and `spark.akka.failure-detector.threshold` - if you need to. -
    spark.akka.failure-detector.threshold300.0 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). This maps to akka's - `akka.remote.transport-failure-detector.threshold`. Tune this in combination of - `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to. -
    spark.akka.heartbeat.interval1000 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). A larger interval value in - seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for - akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and - `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using - failure detector can be, a sensistive failure detector can help evict rogue executors really - quick. However this is usually not the case as gc pauses and network lags are expected in a - real Spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats - between nodes leading to flooding the network with those. + This is set to a larger value to disable the transport failure detector that comes built in to Akka. + It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart + beat pause in seconds for Akka. This can be used to control sensitivity to GC pauses. Tune + this along with `spark.akka.heartbeat.interval` if you need to.
    spark.shuffle.io.preferDirectBufstruespark.akka.threads4 - (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache - block transfer. For environments where off-heap memory is tightly limited, users may wish to - turn this off to force all allocations from Netty to be on-heap. + Number of actor threads to use for communication. Can be useful to increase on large clusters + when the driver has a lot of CPU cores.
    spark.shuffle.io.numConnectionsPerPeer1spark.akka.timeout100 - (Netty only) Connections between hosts are reused in order to reduce connection buildup for - large clusters. For clusters with many hard disks and few hosts, this may result in insufficient - concurrency to saturate all disks, and so users may consider increasing this value. + Communication timeout between Spark nodes, in seconds.
    spark.shuffle.io.maxRetries3spark.blockManager.port(random) - (Netty only) Fetches that fail due to IO-related exceptions are automatically retried if this is - set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC - pauses or transient network connectivity issues. + Port for all block managers to listen on. These exist on both the driver and the executors.
    spark.shuffle.io.retryWait5spark.broadcast.port(random) - (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying - is simply maxRetries * retryWait, by default 15 seconds. + Port for the driver's HTTP broadcast server to listen on. + This is not relevant for torrent broadcast.
    - -#### Scheduling - - + - - + + - - + + - - + + - - + + - - + + - - + + - - + + +
    Property NameDefaultMeaning
    spark.task.cpus1spark.driver.host(local hostname) - Number of cores to allocate for each task. + Hostname or IP address for the driver to listen on. + This is used for communicating with the executors and the standalone Master.
    spark.task.maxFailures4spark.driver.port(random) - Number of individual task failures before giving up on the job. - Should be greater than or equal to 1. Number of allowed retries = this value - 1. + Port for the driver to listen on. + This is used for communicating with the executors and the standalone Master.
    spark.scheduler.modeFIFOspark.executor.port(random) - The scheduling mode between - jobs submitted to the same SparkContext. Can be set to FAIR - to use fair sharing instead of queueing jobs one after another. Useful for - multi-user services. + Port for the executor to listen on. This is used for communicating with the driver.
    spark.cores.max(not set)spark.fileserver.port(random) - When running on a standalone deploy cluster or a - Mesos cluster in "coarse-grained" - sharing mode, the maximum amount of CPU cores to request for the application from - across the cluster (not from each machine). If not set, the default will be - spark.deploy.defaultCores on Spark's standalone cluster manager, or - infinite (all available cores) on Mesos. + Port for the driver's HTTP file server to listen on.
    spark.mesos.coarsefalsespark.network.timeout120 - If set to "true", runs over Mesos clusters in - "coarse-grained" sharing mode, - where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per - Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use - for the whole duration of the Spark job. + Default timeout for all network interactions, in seconds. This config will be used in + place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, + spark.storage.blockManagerSlaveTimeoutMs or + spark.shuffle.io.connectionTimeout, if they are not configured.
    spark.speculationfalsespark.port.maxRetries16 - If set to "true", performs speculative execution of tasks. This means if one or more tasks are - running slowly in a stage, they will be re-launched. + Default maximum number of retries when binding to a port before giving up.
    spark.speculation.interval100spark.replClassServer.port(random) - How often Spark will check for tasks to speculate, in milliseconds. + Port for the driver's HTTP class server to listen on. + This is only relevant for the Spark shell.
    + +#### Scheduling + + - - + + - - + + @@ -994,19 +986,19 @@ Apart from these, the following properties are also available, and may be useful - + - + @@ -1017,14 +1009,14 @@ Apart from these, the following properties are also available, and may be useful - - + + - + - - + + + + + + + - + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.speculation.quantile0.75spark.cores.max(not set) - Percentage of tasks which must be complete before speculation is enabled for a particular stage. + When running on a standalone deploy cluster or a + Mesos cluster in "coarse-grained" + sharing mode, the maximum amount of CPU cores to request for the application from + across the cluster (not from each machine). If not set, the default will be + spark.deploy.defaultCores on Spark's standalone cluster manager, or + infinite (all available cores) on Mesos.
    spark.speculation.multiplier1.5spark.localExecution.enabledfalse - How many times slower a task is than the median to be considered for speculation. + Enables Spark to run certain jobs, such as first() or take() on the driver, without sending + tasks to the cluster. This can make certain jobs execute very quickly, but may require + shipping a whole partition of data to the driver.
    spark.locality.wait.processspark.locality.wait.node spark.locality.wait - Customize the locality wait for process locality. This affects tasks that attempt to access - cached data in a particular executor process. + Customize the locality wait for node locality. For example, you can set this to 0 to skip + node locality and search immediately for rack locality (if your cluster has rack information).
    spark.locality.wait.nodespark.locality.wait.process spark.locality.wait - Customize the locality wait for node locality. For example, you can set this to 0 to skip - node locality and search immediately for rack locality (if your cluster has rack information). + Customize the locality wait for process locality. This affects tasks that attempt to access + cached data in a particular executor process.
    spark.scheduler.revive.interval1000spark.scheduler.maxRegisteredResourcesWaitingTime30000 - The interval length for the scheduler to revive the worker resource offers to run tasks + Maximum amount of time to wait for resources to register before scheduling begins (in milliseconds).
    spark.scheduler.minRegisteredResourcesRatio 0.0 for Mesos and Standalone mode, 0.8 for YARN @@ -1037,25 +1029,70 @@ Apart from these, the following properties are also available, and may be useful
    spark.scheduler.maxRegisteredResourcesWaitingTime30000spark.scheduler.modeFIFO - Maximum amount of time to wait for resources to register before scheduling begins + The scheduling mode between + jobs submitted to the same SparkContext. Can be set to FAIR + to use fair sharing instead of queueing jobs one after another. Useful for + multi-user services. +
    spark.scheduler.revive.interval1000 + The interval length for the scheduler to revive the worker resource offers to run tasks (in milliseconds).
    spark.localExecution.enabledspark.speculation false - Enables Spark to run certain jobs, such as first() or take() on the driver, without sending - tasks to the cluster. This can make certain jobs execute very quickly, but may require - shipping a whole partition of data to the driver. + If set to "true", performs speculative execution of tasks. This means if one or more tasks are + running slowly in a stage, they will be re-launched. +
    spark.speculation.interval100 + How often Spark will check for tasks to speculate, in milliseconds. +
    spark.speculation.multiplier1.5 + How many times slower a task is than the median to be considered for speculation. +
    spark.speculation.quantile0.75 + Percentage of tasks which must be complete before speculation is enabled for a particular stage. +
    spark.task.cpus1 + Number of cores to allocate for each task. +
    spark.task.maxFailures4 + Number of individual task failures before giving up on the job. + Should be greater than or equal to 1. Number of allowed retries = this value - 1.
    -#### Dynamic allocation +#### Dynamic Allocation @@ -1067,29 +1104,46 @@ Apart from these, the following properties are also available, and may be useful available on YARN mode. For more detail, see the description here.

    - This requires the following configurations to be set: + This requires spark.shuffle.service.enabled to be set. + The following configurations are also relevant: spark.dynamicAllocation.minExecutors, spark.dynamicAllocation.maxExecutors, and - spark.shuffle.service.enabled + spark.dynamicAllocation.initialExecutors + + + + + + + - - + + + + + + - + - - - - -
    Property NameDefaultMeaning
    spark.dynamicAllocation.executorIdleTimeout600 + If dynamic allocation is enabled and an executor has been idle for more than this duration + (in seconds), the executor will be removed. For more detail, see this + description.
    spark.dynamicAllocation.initialExecutors spark.dynamicAllocation.minExecutors(none) - Lower bound for the number of executors if dynamic allocation is enabled (required). + Initial number of executors to run if dynamic allocation is enabled.
    spark.dynamicAllocation.maxExecutors(none)Integer.MAX_VALUE + Upper bound for the number of executors if dynamic allocation is enabled. +
    spark.dynamicAllocation.minExecutors0 - Upper bound for the number of executors if dynamic allocation is enabled (required). + Lower bound for the number of executors if dynamic allocation is enabled.
    spark.dynamicAllocation.schedulerBacklogTimeout605 If dynamic allocation is enabled and there have been pending tasks backlogged for more than this duration (in seconds), new executors will be requested. For more detail, see this @@ -1105,20 +1159,30 @@ Apart from these, the following properties are also available, and may be useful description.
    spark.dynamicAllocation.executorIdleTimeout600 - If dynamic allocation is enabled and an executor has been idle for more than this duration - (in seconds), the executor will be removed. For more detail, see this - description. -
    #### Security + + + + + + + + + + @@ -1135,6 +1199,15 @@ Apart from these, the following properties are also available, and may be useful not running on YARN and authentication is enabled. + + + + + @@ -1144,12 +1217,11 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -1166,16 +1238,6 @@ Apart from these, the following properties are also available, and may be useful -Dspark.com.test.filter1.params='param1=foo,param2=testing' - - - - - @@ -1184,25 +1246,88 @@ Apart from these, the following properties are also available, and may be useful user that started the Spark job has view access. - - - - - - - - - -
    Property NameDefaultMeaning
    spark.acls.enablefalse + Whether Spark acls should be enabled. If enabled, this checks to see if the user has + access permissions to view or modify the job. Note this requires the user to be known, + so if the user comes across as null no checks are done. Filters can be used with the UI + to authenticate and set the user. +
    spark.admin.aclsEmpty + Comma separated list of users/administrators that have view and modify access to all Spark jobs. + This can be used if you run on a shared cluster and have a set of administrators or devs who + help debug when things do not work. +
    spark.authenticate false
    spark.core.connection.ack.wait.timeout60 + Number of seconds for the connection to wait for ack to occur before timing + out and giving up. To avoid unwanted timeouts caused by long pauses such as GC, + you can set a larger value. +
    spark.core.connection.auth.wait.timeout 30
    spark.core.connection.ack.wait.timeout60spark.modify.aclsEmpty - Number of seconds for the connection to wait for ack to occur before timing - out and giving up. To avoid unwilling timeout caused by long pause like GC, - you can set larger value. + Comma separated list of users that have modify access to the Spark job. By default only the + user that started the Spark job has access to modify it (kill it for example).
    spark.acls.enablefalse - Whether Spark acls should are enabled. If enabled, this checks to see if the user has - access permissions to view or modify the job. Note this requires the user to be known, - so if the user comes across as null no checks are done. Filters can be used with the UI - to authenticate and set the user. -
    spark.ui.view.acls Empty
    spark.modify.aclsEmpty - Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). -
    spark.admin.aclsEmpty - Comma separated list of users/administrators that have view and modify access to all Spark jobs. - This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. -
    +#### Encryption + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ssl.enabledfalse +

    Whether to enable SSL connections on all supported protocols.

    + +

    All the SSL settings like spark.ssl.xxx where xxx is a + particular configuration property, denote the global configuration for all the supported + protocols. In order to override the global configuration for a particular protocol, + the properties must be overwritten in the protocol-specific namespace.

    + +

    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for + a particular protocol denoted by YYY. Currently YYY can be + either akka for Akka based connections or fs for broadcast and + file server.

    +
    spark.ssl.enabledAlgorithmsEmpty + A comma separated list of ciphers. The specified ciphers must be supported by JVM. + The reference list of protocols one can find on + this + page. +
    spark.ssl.keyPasswordNone + A password to the private key in key-store. +
    spark.ssl.keyStoreNone + A path to a key-store file. The path can be absolute or relative to the directory where + the component is started in. +
    spark.ssl.keyStorePasswordNone + A password to the key-store. +
    spark.ssl.protocolNone + A protocol name. The protocol must be supported by JVM. The reference list of protocols + one can find on this + page. +
    spark.ssl.trustStoreNone + A path to a trust-store file. The path can be absolute or relative to the directory + where the component is started in. +
    spark.ssl.trustStorePasswordNone + A password to the trust-store. +
    + + #### Spark Streaming diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index d50f445d7ecc7..8c9a1e1262d8f 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -52,7 +52,7 @@ identify machines belonging to each cluster in the Amazon EC2 Console. ```bash export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --spark-version=1.1.0 launch my-spark-cluster +./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a launch my-spark-cluster ``` - After everything launches, check that the cluster scheduler is up and sees diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index e298c51f8a5b7..28bdf81ca0ca5 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: GraphX Programming Guide +displayTitle: GraphX Programming Guide +title: GraphX +description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). @@ -536,7 +538,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts, ## Neighborhood Aggregation -A key step in may graph analytics tasks is aggregating information about the neighborhood of each +A key step in many graph analytics tasks is aggregating information about the neighborhood of each vertex. For example, we might want to know the number of followers each user has or the average age of the the followers of each user. 
Many iterative graph algorithms (e.g., PageRank, Shortest Path, and @@ -632,7 +634,7 @@ avgAgeOfOlderFollowers.collect.foreach(println(_)) ### Map Reduce Triplets Transition Guide (Legacy) -In earlier versions of GraphX we neighborhood aggregation was accomplished using the +In earlier versions of GraphX neighborhood aggregation was accomplished using the [`mapReduceTriplets`][Graph.mapReduceTriplets] operator: {% highlight scala %} @@ -680,8 +682,8 @@ val result = graph.aggregateMessages[String](msgFun, reduceFun) ### Computing Degree Information A common aggregation task is computing the degree of each vertex: the number of edges adjacent to -each vertex. In the context of directed graphs it often necessary to know the in-degree, out- -degree, and the total degree of each vertex. The [`GraphOps`][GraphOps] class contains a +each vertex. In the context of directed graphs it is often necessary to know the in-degree, +out-degree, and the total degree of each vertex. The [`GraphOps`][GraphOps] class contains a collection of operators to compute the degrees of each vertex. For example in the following we compute the max in, out, and total degrees: diff --git a/docs/index.md b/docs/index.md index 171d6ddad62f3..e006be640e582 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,8 @@ --- layout: global -title: Spark Overview +displayTitle: Spark Overview +title: Overview +description: Apache Spark SPARK_VERSION_SHORT documentation homepage --- Apache Spark is a fast and general-purpose cluster computing system. diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index a5425eb3557b2..5295e351dd711 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -77,11 +77,10 @@ scheduling while sharing cluster resources efficiently. ### Configuration and Setup All configurations used by this feature live under the `spark.dynamicAllocation.*` namespace. 
-To enable this feature, your application must set `spark.dynamicAllocation.enabled` to `true` and -provide lower and upper bounds for the number of executors through -`spark.dynamicAllocation.minExecutors` and `spark.dynamicAllocation.maxExecutors`. Other relevant -configurations are described on the [configurations page](configuration.html#dynamic-allocation) -and in the subsequent sections in detail. +To enable this feature, your application must set `spark.dynamicAllocation.enabled` to `true`. +Other relevant configurations are described on the +[configurations page](configuration.html#dynamic-allocation) and in the subsequent sections in +detail. Additionally, your application must use an external shuffle service. The purpose of the service is to preserve the shuffle files written by executors so the executors can be safely removed (more diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be178d7689fdd..da6aef7f14c4c 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -23,13 +23,13 @@ to `spark.ml`. Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API. -* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`SchemaRDD`](api/scala/index.html#org.apache.spark.sql.SchemaRDD) from Spark SQL as a dataset which can hold a variety of data types. +* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL as a dataset which can hold a variety of data types. E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions. -* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `SchemaRDD` into another `SchemaRDD`. 
+* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. E.g., an ML model is a `Transformer` which transforms an RDD with features into an RDD with predictions. -* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `SchemaRDD` to produce a `Transformer`. +* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. E.g., a learning algorithm is an `Estimator` which trains on a dataset and produces a model. * **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. @@ -39,20 +39,20 @@ E.g., a learning algorithm is an `Estimator` which trains on a dataset and produ ## ML Dataset Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. -Spark ML adopts the [`SchemaRDD`](api/scala/index.html#org.apache.spark.sql.SchemaRDD) from Spark SQL in order to support a variety of data types under a unified Dataset concept. +Spark ML adopts the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL in order to support a variety of data types under a unified Dataset concept. -`SchemaRDD` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. -In addition to the types listed in the Spark SQL guide, `SchemaRDD` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. 
+In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. -A `SchemaRDD` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. +A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. -Columns in a `SchemaRDD` are named. The code examples below use names such as "text," "features," and "label." +Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." ## ML Algorithms ### Transformers -A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `SchemaRDD` into another, generally by appending one or more columns. +A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `DataFrame` into another, generally by appending one or more columns. For example: * A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset. @@ -60,7 +60,7 @@ For example: ### Estimators -An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `SchemaRDD` and produces a `Transformer`. 
+An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `DataFrame` and produces a `Transformer`. For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling `fit()` trains a `LogisticRegressionModel`, which is a `Transformer`. ### Properties of ML Algorithms @@ -101,7 +101,7 @@ We illustrate this for the simple text document workflow. The figure below is f Above, the top row represents a `Pipeline` with three stages. The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). -The bottom row represents data flowing through the pipeline, where cylinders indicate `SchemaRDD`s. +The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. The `Pipeline.fit()` method is called on the original dataset which has raw text documents and labels. The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words into the dataset. The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the dataset. @@ -130,7 +130,7 @@ Each stage's `transform()` method updates the dataset and passes it to the next *DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. 
-*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `SchemaRDD`. +*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `DataFrame`. ## Parameters @@ -171,12 +171,12 @@ import org.apache.spark.sql.{Row, SQLContext} val conf = new SparkConf().setAppName("SimpleParamsExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) -import sqlContext._ +import sqlContext.implicits._ // Prepare training data. // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into SchemaRDDs, where it uses the case class metadata to infer the schema. -val training = sparkContext.parallelize(Seq( +// into DataFrames, where it uses the case class metadata to infer the schema. +val training = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), @@ -192,7 +192,7 @@ lr.setMaxIter(10) .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. -val model1 = lr.fit(training) +val model1 = lr.fit(training.toDF) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). 
// This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -203,33 +203,35 @@ println("Model 1 was fit using parameters: " + model1.fittingParamMap) // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. -paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.5) // Specify multiple Params. +paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. -val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Changes output column name. +val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. -val model2 = lr.fit(training, paramMapCombined) +val model2 = lr.fit(training.toDF, paramMapCombined) println("Model 2 was fit using parameters: " + model2.fittingParamMap) -// Prepare test documents. -val test = sparkContext.parallelize(Seq( +// Prepare test data. +val test = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) -// Make predictions on test documents using the Transformer.transform() method. +// Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. -// Note that model2.transform() outputs a 'probability' column instead of the usual 'score' -// column since we renamed the lr.scoreCol parameter previously. 
-model2.transform(test) - .select('features, 'label, 'probability, 'prediction) +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +model2.transform(test.toDF) + .select("features", "label", "myProbability", "prediction") .collect() - .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => - println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") } + +sc.stop() {% endhighlight %} @@ -244,23 +246,23 @@ import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Row; SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); JavaSparkContext jsc = new JavaSparkContext(conf); -JavaSQLContext jsql = new JavaSQLContext(jsc); +SQLContext jsql = new SQLContext(jsc); // Prepare training data. -// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into SchemaRDDs, where it uses the case class metadata to infer the schema. +// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans +// into DataFrames, where it uses the bean metadata to infer the schema.
List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); -JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); +DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -281,13 +283,13 @@ System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap // We may alternatively specify parameters using a ParamMap. ParamMap paramMap = new ParamMap(); -paramMap.put(lr.maxIter(), 20); // Specify 1 Param. +paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. -paramMap.put(lr.regParam(), 0.1); +paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); -paramMap2.put(lr.scoreCol(), "probability"); // Changes output column name. +paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -300,19 +302,19 @@ List localTest = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); -JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); +DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. 
-// Note that model2.transform() outputs a 'probability' column instead of the usual 'score' -// column since we renamed the lr.scoreCol parameter previously. -model2.transform(test).registerAsTable("results"); -JavaSchemaRDD results = - jsql.sql("SELECT features, label, probability, prediction FROM results"); -for (Row r: results.collect()) { +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +DataFrame results = model2.transform(test); +for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + +jsc.stop(); {% endhighlight %} @@ -330,6 +332,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} // Labeled and unlabeled instance types. @@ -337,14 +340,14 @@ import org.apache.spark.sql.{Row, SQLContext} case class LabeledDocument(id: Long, text: String, label: Double) case class Document(id: Long, text: String) -// Set up contexts. Import implicit conversions to SchemaRDD from sqlContext. +// Set up contexts. Import implicit conversions to DataFrame from sqlContext. val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) -import sqlContext._ +import sqlContext.implicits._ // Prepare training documents, which are labeled. 
-val training = sparkContext.parallelize(Seq( +val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -365,30 +368,32 @@ val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. -val model = pipeline.fit(training) +val model = pipeline.fit(training.toDF) // Prepare test documents, which are unlabeled. -val test = sparkContext.parallelize(Seq( +val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. -model.transform(test) - .select('id, 'text, 'score, 'prediction) +model.transform(test.toDF) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } + +sc.stop() {% endhighlight %}
    {% highlight java %} -import java.io.Serializable; import java.util.List; import com.google.common.collect.Lists; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; @@ -396,10 +401,9 @@ import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; -import org.apache.spark.SparkConf; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. @@ -434,7 +438,7 @@ public class LabeledDocument extends Document implements Serializable { // Set up contexts. SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); JavaSparkContext jsc = new JavaSparkContext(conf); -JavaSQLContext jsql = new JavaSQLContext(jsc); +SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -442,8 +446,7 @@ List localTraining = Lists.newArrayList( new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); -JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); +DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. 
Tokenizer tokenizer = new Tokenizer() @@ -468,16 +471,62 @@ List localTest = Lists.newArrayList( new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); -JavaSchemaRDD test = - jsql.applySchema(jsc.parallelize(localTest), Document.class); +DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. -model.transform(test).registerAsTable("prediction"); -JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); -for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) +DataFrame predictions = model.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + +jsc.stop(); +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row, SQLContext + +sc = SparkContext(appName="SimpleTextClassificationPipeline") +sqlCtx = SQLContext(sc) + +# Prepare training documents, which are labeled. +LabeledDocument = Row("id", "text", "label") +training = sc.parallelize([(0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + +# Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +tokenizer = Tokenizer(inputCol="text", outputCol="words") +hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") +lr = LogisticRegression(maxIter=10, regParam=0.01) +pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + +# Fit the pipeline to training documents. +model = pipeline.fit(training) + +# Prepare test documents, which are unlabeled. +Document = Row("id", "text") +test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + +# Make predictions on test documents and print columns of interest. +prediction = model.transform(test) +selected = prediction.select("id", "text", "prediction") +for row in selected.collect(): + print row + +sc.stop() {% endhighlight %}
    @@ -508,21 +557,21 @@ However, it is also a well-established method for choosing parameters which is m
    {% highlight scala %} import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} val conf = new SparkConf().setAppName("CrossValidatorExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) -import sqlContext._ +import sqlContext.implicits._ // Prepare training documents, which are labeled. -val training = sparkContext.parallelize(Seq( +val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -565,24 +614,24 @@ crossval.setEstimatorParamMaps(paramGrid) crossval.setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. -val cvModel = crossval.fit(training) -// Get the best LogisticRegression model (with the best set of parameters from paramGrid). -val lrModel = cvModel.bestModel +val cvModel = crossval.fit(training.toDF) // Prepare test documents, which are unlabeled. -val test = sparkContext.parallelize(Seq( +val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. cvModel uses the best model found (lrModel). 
-cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) +cvModel.transform(test.toDF) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } + +sc.stop() {% endhighlight %}
    @@ -592,7 +641,6 @@ import java.util.List; import com.google.common.collect.Lists; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.Model; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; @@ -603,13 +651,13 @@ import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); JavaSparkContext jsc = new JavaSparkContext(conf); -JavaSQLContext jsql = new JavaSQLContext(jsc); +SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -625,8 +673,7 @@ List localTraining = Lists.newArrayList( new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); -JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); +DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -660,8 +707,6 @@ crossval.setNumFolds(2); // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. CrossValidatorModel cvModel = crossval.fit(training); -// Get the best LogisticRegression model (with the best set of parameters from paramGrid). 
-Model lrModel = cvModel.bestModel(); // Prepare test documents, which are unlabeled. List localTest = Lists.newArrayList( @@ -669,15 +714,16 @@ List localTest = Lists.newArrayList( new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); -JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); +DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test).registerAsTable("prediction"); -JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); -for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) +DataFrame predictions = cvModel.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + +jsc.stop(); {% endhighlight %} @@ -686,6 +732,21 @@ for (Row r: predictions.collect()) { # Dependencies Spark ML currently depends on MLlib and has the same dependencies. -Please see the [MLlib Dependencies guide](mllib-guide.html#Dependencies) for more info. +Please see the [MLlib Dependencies guide](mllib-guide.html#dependencies) for more info. Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies. + +# Migration Guide + +## From 1.2 to 1.3 + +The main API changes are from Spark SQL. We list the most important changes here: + +* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. 
+* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. +* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. + +Other changes were in `LogisticRegression`: + +* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). +* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 719cc95767b00..8e91d62f4a907 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -17,13 +17,13 @@ the supported algorithms for each type of problem.
    - + - + - +
    Property NameDefaultMeaning
    Binary Classificationlinear SVMs, logistic regression, decision trees, naive BayesBinary Classificationlinear SVMs, logistic regression, decision trees, random forests, gradient-boosted trees, naive Bayes
    Multiclass Classificationdecision trees, naive BayesMulticlass Classificationdecision trees, random forests, naive Bayes
    Regressionlinear least squares, Lasso, ridge regression, decision treesRegressionlinear least squares, Lasso, ridge regression, decision trees, random forests, gradient-boosted trees, isotonic regression
    @@ -34,4 +34,8 @@ More details for these methods can be found here: * [binary classification (SVMs, logistic regression)](mllib-linear-methods.html#binary-classification) * [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) * [Decision trees](mllib-decision-tree.html) +* [Ensembles of decision trees](mllib-ensembles.html) + * [random forests](mllib-ensembles.html#random-forests) + * [gradient-boosted trees](mllib-ensembles.html#gradient-boosted-trees-gbts) * [Naive Bayes](mllib-naive-bayes.html) +* [Isotonic regression](mllib-isotonic-regression.html) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index c696ae9c8e8c8..0b6db4fcb7b1f 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -4,25 +4,25 @@ title: Clustering - MLlib displayTitle: MLlib - Clustering --- -* Table of contents -{:toc} - - -## Clustering - Clustering is an unsupervised learning problem whereby we aim to group subsets of entities with one another based on some notion of similarity. Clustering is often used for exploratory analysis and/or as a component of a hierarchical supervised learning pipeline (in which distinct classifiers or regression -models are trained for each cluster). +models are trained for each cluster). + +MLlib supports the following models: -MLlib supports -[k-means](http://en.wikipedia.org/wiki/K-means_clustering) clustering, one of -the most commonly used clustering algorithms that clusters the data points into +* Table of contents +{:toc} + +## K-means + +[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +most commonly used clustering algorithms that clusters the data points into a predefined number of clusters. The MLlib implementation includes a parallelized variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). 
-The implementation in MLlib has the following parameters: +The implementation in MLlib has the following parameters: * *k* is the number of desired clusters. * *maxIterations* is the maximum number of iterations to run. @@ -32,9 +32,9 @@ initialization via k-means\|\|. guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. -* *epsilon* determines the distance threshold within which we consider k-means to have converged. +* *epsilon* determines the distance threshold within which we consider k-means to have converged. -### Examples +**Examples**
    @@ -148,41 +148,370 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
    -In order to run the above application, follow the instructions -provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) -section of the Spark -Quick Start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +## Gaussian mixture + +A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) +represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, +each with its own probability. The MLlib implementation uses the +[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) + algorithm to induce the maximum-likelihood model given a set of samples. The implementation +has the following parameters: + +* *k* is the number of desired clusters. +* *convergenceTol* is the maximum change in log-likelihood at which we consider convergence achieved. +* *maxIterations* is the maximum number of iterations to perform without reaching convergence. +* *initialModel* is an optional starting point from which to start the EM algorithm. If this parameter is omitted, a random starting point will be constructed from the data. + +**Examples** + +
    +
    +In the following example after loading and parsing data, we use a +[GaussianMixture](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixture) +object to cluster the data into two clusters. The number of desired clusters is passed +to the algorithm. We then output the parameters of the mixture model. + +{% highlight scala %} +import org.apache.spark.mllib.clustering.GaussianMixture +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/gmm_data.txt") +val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))).cache() + +// Cluster the data into two classes using GaussianMixture +val gmm = new GaussianMixture().setK(2).run(parsedData) + +// output parameters of max-likelihood model +for (i <- 0 until gmm.k) { + println("weight=%f\nmu=%s\nsigma=\n%s\n" format + (gmm.weights(i), gmm.gaussians(i).mu, gmm.gaussians(i).sigma)) +} + +{% endhighlight %} +
    + +
    +All of MLlib's methods use Java-friendly types, so you can import and call them there the same +way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the +Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by +calling `.rdd()` on your `JavaRDD` object. A self-contained application example +that is equivalent to the provided example in Scala is given below: + +{% highlight java %} +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.GaussianMixture; +import org.apache.spark.mllib.clustering.GaussianMixtureModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class GaussianMixtureExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("GaussianMixture Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse data + String path = "data/mllib/gmm_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + parsedData.cache(); + + // Cluster the data into two classes using GaussianMixture + GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); + + // Output the parameters of the mixture model + for(int j=0; j + +
    +In the following example after loading and parsing data, we use a +[GaussianMixture](api/python/pyspark.mllib.html#pyspark.mllib.clustering.GaussianMixture) +object to cluster the data into two clusters. The number of desired clusters is passed +to the algorithm. We then output the parameters of the mixture model. + +{% highlight python %} +from pyspark.mllib.clustering import GaussianMixture +from numpy import array + +# Load and parse the data +data = sc.textFile("data/mllib/gmm_data.txt") +parsedData = data.map(lambda line: array([float(x) for x in line.strip().split(' ')])) + +# Build the model (cluster the data) +gmm = GaussianMixture.train(parsedData, 2) + +# output parameters of model +for i in range(2): + print ("weight = ", gmm.weights[i], "mu = ", gmm.gaussians[i].mu, + "sigma = ", gmm.gaussians[i].sigma.toArray()) + +{% endhighlight %} +
    + +
    + +## Power iteration clustering (PIC) -## Streaming clustering +Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering vertices of a +graph given pairwise similarities as edge properties, +described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). +It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via +[power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. +MLlib includes an implementation of PIC using GraphX as its backend. +It takes an `RDD` of `(srcId, dstId, similarity)` tuples and outputs a model with the clustering assignments. +The similarities must be nonnegative. +PIC assumes that the similarity measure is symmetric. +A pair `(srcId, dstId)` regardless of the ordering should appear at most once in the input data. +If a pair is missing from input, their similarity is treated as zero. +MLlib's PIC implementation takes the following (hyper-)parameters: -When data arrive in a stream, we may want to estimate clusters dynamically, -updating them as new data arrive. MLlib provides support for streaming k-means clustering, -with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm -uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign +* `k`: number of clusters +* `maxIterations`: maximum number of power iterations +* `initializationMode`: initialization mode. This can be either "random", which is the default, + to use a random vector as vertex properties, or "degree" to use normalized sum similarities. + +**Examples** + +In the following, we show code snippets to demonstrate how to use PIC in MLlib. + +
    +
    + +[`PowerIterationClustering`](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClustering) +implements the PIC algorithm. +It takes an `RDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the +affinity matrix. +Calling `PowerIterationClustering.run` returns a +[`PowerIterationClusteringModel`](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClusteringModel), +which contains the computed clustering assignments. + +{% highlight scala %} +import org.apache.spark.mllib.clustering.PowerIterationClustering +import org.apache.spark.mllib.linalg.Vectors + +val similarities: RDD[(Long, Long, Double)] = ... + +val pic = new PowerIterationClustering() + .setK(3) + .setMaxIterations(20) +val model = pic.run(similarities) + +model.assignments.foreach { a => + println(s"${a.id} -> ${a.cluster}") +} +{% endhighlight %} + +A full example that produces the experiment described in the PIC paper can be found under +[`examples/`](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala). + +
    + +
    + +[`PowerIterationClustering`](api/java/org/apache/spark/mllib/clustering/PowerIterationClustering.html) +implements the PIC algorithm. +It takes a `JavaRDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the +affinity matrix. +Calling `PowerIterationClustering.run` returns a +[`PowerIterationClusteringModel`](api/java/org/apache/spark/mllib/clustering/PowerIterationClusteringModel.html) +which contains the computed clustering assignments. + +{% highlight java %} +import scala.Tuple2; +import scala.Tuple3; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.clustering.PowerIterationClustering; +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; + +JavaRDD<Tuple3<Long, Long, Double>> similarities = ... + +PowerIterationClustering pic = new PowerIterationClustering() + .setK(2) + .setMaxIterations(10); +PowerIterationClusteringModel model = pic.run(similarities); + +for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { + System.out.println(a.id() + " -> " + a.cluster()); +} +{% endhighlight %} +
    + +
    + +## Latent Dirichlet allocation (LDA) + +[Latent Dirichlet allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) +is a topic model which infers topics from a collection of text documents. +LDA can be thought of as a clustering algorithm as follows: + +* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. +* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. +* Rather than estimating a clustering using a traditional distance, LDA uses a function based + on a statistical model of how text documents are generated. + +LDA takes in a collection of documents as vectors of word counts. +It learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +on the likelihood function. After fitting on the documents, LDA provides: + +* Topics: Inferred topics, each of which is a probability distribution over terms (words). +* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics. + +LDA takes the following parameters: + +* `k`: Number of topics (i.e., cluster centers) +* `maxIterations`: Limit on the number of iterations of EM used for learning +* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. +* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. +* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. 
+ +*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet +support prediction on new documents, and it does not have a Python API. These will be added in the future. + +**Examples** + +In the following example, we load word count vectors representing a corpus of documents. +We then use [LDA](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) +to infer three topics from the documents. The number of desired clusters is passed +to the algorithm. We then output the topics, represented as probability distributions over words. + +
    +
    + +{% highlight scala %} +import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/sample_lda_data.txt") +val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))) +// Index documents with unique IDs +val corpus = parsedData.zipWithIndex.map(_.swap).cache() + +// Cluster the documents into three topics using LDA +val ldaModel = new LDA().setK(3).run(corpus) + +// Output topics. Each is a distribution over words (matching word count vectors) +println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") +val topics = ldaModel.topicsMatrix +for (topic <- Range(0, 3)) { + print("Topic " + topic + ":") + for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + println() +} +{% endhighlight %} +
    + +
    +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class JavaLDAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("LDA Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs + JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + )); + corpus.cache(); + + // Cluster the documents into three topics using LDA + DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + } +} +{% endhighlight %} +
    + +
    + +## Streaming k-means + +When data arrive in a stream, we may want to estimate clusters dynamically, +updating them as new data arrive. MLlib provides support for streaming k-means clustering, +with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm +uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute new cluster centers, then update each cluster using: `\begin{equation} c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} \end{equation}` `\begin{equation} - n_{t+1} = n_t + m_t + n_{t+1} = n_t + m_t \end{equation}` -Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned -to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` -is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` -can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; -with `$\alpha$=0` only the most recent data will be used. This is analogous to an -exponentially-weighted moving average. +Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned +to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` +is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` +can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; +with `$\alpha$=0` only the most recent data will be used. This is analogous to an +exponentially-weighted moving average. -The decay can be specified using a `halfLife` parameter, which determines the +The decay can be specified using a `halfLife` parameter, which determines the correct decay factor `a` such that, for data acquired at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5. 
The unit of time can be specified either as `batches` or `points` and the update rule will be adjusted accordingly. -### Examples +**Examples** This example shows how to estimate clusters on streaming data. @@ -200,9 +529,9 @@ import org.apache.spark.mllib.clustering.StreamingKMeans {% endhighlight %} -Then we make an input stream of vectors for training, as well as a stream of labeled data -points for testing. We assume a StreamingContext `ssc` has been created, see -[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created, see +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. {% highlight scala %} @@ -224,24 +553,24 @@ val model = new StreamingKMeans() {% endhighlight %} -Now register the streams for training and testing and start the job, printing +Now register the streams for training and testing and start the job, printing the predicted cluster assignments on new data points as they arrive. {% highlight scala %} model.trainOn(trainingData) -model.predictOnValues(testData).print() +model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} -As you add new text files with data the cluster centers will update. Each training +As you add new text files with data the cluster centers will update. Each training point should be formatted as `[x1, x2, x3]`, and each test data point -should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier -(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` +should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier +(e.g. 
a true category assignment). Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. With new data, the cluster centers will change!
    diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 2094963392295..935cd8dad3b25 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -66,6 +66,7 @@ recommendation model by measuring the Mean Squared Error of rating prediction. {% highlight scala %} import org.apache.spark.mllib.recommendation.ALS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel import org.apache.spark.mllib.recommendation.Rating // Load and parse the data @@ -95,6 +96,9 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => err * err }.mean() println("Mean Squared Error = " + MSE) + +model.save("myModelPath") +val sameModel = MatrixFactorizationModel.load("myModelPath") {% endhighlight %} If the rating matrix is derived from another source of information (e.g., it is inferred from @@ -181,6 +185,9 @@ public class CollaborativeFiltering { } ).rdd()).mean(); System.out.println("Mean Squared Error = " + MSE); + + model.save("myModelPath"); + MatrixFactorizationModel sameModel = MatrixFactorizationModel.load("myModelPath"); } } {% endhighlight %} @@ -191,13 +198,14 @@ In the following example we load rating data. Each row consists of a user, a pro We use the default ALS.train() method which assumes ratings are explicit. We evaluate the recommendation by measuring the Mean Squared Error of rating prediction. +Note that the Python API does not yet support model save/load but will in the future. 
+ {% highlight python %} -from pyspark.mllib.recommendation import ALS -from numpy import array +from pyspark.mllib.recommendation import ALS, Rating # Load and parse the data data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda line: array([float(x) for x in line.split(',')])) +ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) # Build the recommendation model using Alternating Least Squares rank = 10 @@ -205,10 +213,10 @@ numIterations = 20 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data -testdata = ratings.map(lambda p: (int(p[0]), int(p[1]))) +testdata = ratings.map(lambda p: (p[0], p[1])) predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count() +MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count() print("Mean Squared Error = " + str(MSE)) {% endhighlight %} @@ -217,7 +225,7 @@ signals), you can use the trainImplicit method to get better results. {% highlight python %} # Build the recommendation model using Alternating Least Squares based on implicit ratings -model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) +model = ALS.trainImplicit(ratings, rank, numIterations, alpha=0.01) {% endhighlight %}
    diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 101dc2f8695f3..fe6c1bf7bfd99 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -296,6 +296,70 @@ backed by an RDD of its entries. The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. In general the use of non-deterministic RDDs can lead to errors. +### BlockMatrix + +A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is +a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is +the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. +`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. +`BlockMatrix` also has a helper function `validate` which can be used to check whether the +`BlockMatrix` is set up properly. + +
    +
    + +A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} + +val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries +// Create a CoordinateMatrix from an RDD[MatrixEntry]. +val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) +// Transform the CoordinateMatrix to a BlockMatrix +val matA: BlockMatrix = coordMat.toBlockMatrix().cache() + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate() + +// Calculate A^T A. +val ata = matA.transpose.multiply(matA) +{% endhighlight %} +
    + +
    + +A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.distributed.BlockMatrix; +import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; +import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; + +JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries +// Create a CoordinateMatrix from a JavaRDD. +CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); +// Transform the CoordinateMatrix to a BlockMatrix +BlockMatrix matA = coordMat.toBlockMatrix().cache(); + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate(); + +// Calculate A^T A. +BlockMatrix ata = matA.transpose().multiply(matA); +{% endhighlight %} +
    +
    + ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index fc8e732251a30..4695d1cde4901 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Tree - MLlib -displayTitle: MLlib - Decision Tree +title: Decision Trees - MLlib +displayTitle: MLlib - Decision Trees --- * Table of contents @@ -54,8 +54,8 @@ impurity measure for regression (variance). Variance Regression - $\frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$y_i$ is label for an instance, - $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N x_i$. + $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$$y_i$ is label for an instance, + $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N y_i$. @@ -194,6 +194,7 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
    {% highlight scala %} import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -221,6 +222,9 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification tree model:\n" + model.toDebugString) + +model.save("myModelPath") +val sameModel = DecisionTreeModel.load("myModelPath") {% endhighlight %}
    @@ -279,10 +283,16 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification tree model:\n" + model.toDebugString()); + +model.save("myModelPath"); +DecisionTreeModel sameModel = DecisionTreeModel.load("myModelPath"); {% endhighlight %}
+ +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.tree import DecisionTree @@ -324,6 +334,7 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
{% highlight scala %} import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -350,6 +361,9 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression tree model:\n" + model.toDebugString) + +model.save("myModelPath") +val sameModel = DecisionTreeModel.load("myModelPath") {% endhighlight %}
@@ -414,10 +428,16 @@ Double testMSE = }) / data.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression tree model:\n" + model.toDebugString()); + +model.save("myModelPath"); +DecisionTreeModel sameModel = DecisionTreeModel.load("myModelPath"); {% endhighlight %}
+ +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.tree import DecisionTree diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 23ede04b62d5b..ddae84165f8a9 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -98,6 +98,7 @@ The test error is calculated to measure the algorithm accuracy.
{% highlight scala %} import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -127,6 +128,9 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification forest model:\n" + model.toDebugString) + +model.save("myModelPath") +val sameModel = RandomForestModel.load("myModelPath") {% endhighlight %}
@@ -188,10 +192,16 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification forest model:\n" + model.toDebugString()); + +model.save("myModelPath"); +RandomForestModel sameModel = RandomForestModel.load("myModelPath"); {% endhighlight %}
+ +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.tree import RandomForest from pyspark.mllib.util import MLUtils @@ -235,6 +245,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
{% highlight scala %} import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -264,6 +275,9 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression forest model:\n" + model.toDebugString) + +model.save("myModelPath") +val sameModel = RandomForestModel.load("myModelPath") {% endhighlight %}
@@ -328,10 +342,16 @@ Double testMSE = }) / testData.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression forest model:\n" + model.toDebugString()); + +model.save("myModelPath"); +RandomForestModel sameModel = RandomForestModel.load("myModelPath"); {% endhighlight %}
+ +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.tree import RandomForest from pyspark.mllib.util import MLUtils @@ -427,10 +447,19 @@ We omit some decision tree parameters since those are covered in the [decision t * **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter. +#### Validation while training -### Examples +Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while +training. The method `runWithValidation` has been provided to make use of this option. It takes a pair of RDDs as arguments, the +first one being the training dataset and the second being the validation dataset. -GBTs currently have APIs in Scala and Java. Examples in both languages are shown below. +The training is stopped when the improvement in the validation error is not more than a certain tolerance +(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error +decreases initially and later increases. There might be cases in which the validation error does not change monotonically, +and the user is advised to set a large enough negative tolerance and examine the validation curve to tune the number of +iterations. + +### Examples #### Classification @@ -446,6 +475,7 @@ The test error is calculated to measure the algorithm accuracy. {% highlight scala %} import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -458,7 +488,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // The defaultParams for Classification use LogLoss by default. 
val boostingStrategy = BoostingStrategy.defaultParams("Classification") boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.numClassesForClassification = 2 +boostingStrategy.treeStrategy.numClasses = 2 boostingStrategy.treeStrategy.maxDepth = 5 // Empty categoricalFeaturesInfo indicates all features are continuous. boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() @@ -473,6 +503,9 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification GBT model:\n" + model.toDebugString) + +model.save("myModelPath") +val sameModel = GradientBoostedTreesModel.load("myModelPath") {% endhighlight %}
@@ -534,6 +567,38 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification GBT model:\n" + model.toDebugString()); + +model.save("myModelPath"); +GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load("myModelPath"); +{% endhighlight %} + + +
+ +Note that the Python API does not yet support model save/load but will in the future. + +{% highlight python %} +from pyspark.mllib.tree import GradientBoostedTrees +from pyspark.mllib.util import MLUtils + +# Load and parse the data file. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GradientBoostedTrees model. +# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. +# (b) Use more iterations in practice. +model = GradientBoostedTrees.trainClassifier(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + +# Evaluate model on test instances and compute test error +predictions = model.predict(testData.map(lambda x: x.features)) +labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) +testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) +print('Test Error = ' + str(testErr)) +print('Learned classification GBT model:') +print(model.toDebugString()) {% endhighlight %}
@@ -554,6 +619,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate {% highlight scala %} import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -580,6 +646,9 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression GBT model:\n" + model.toDebugString) + +model.save("myModelPath") +val sameModel = GradientBoostedTreesModel.load("myModelPath") {% endhighlight %} @@ -647,6 +716,38 @@ Double testMSE = }) / data.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression GBT model:\n" + model.toDebugString()); + +model.save("myModelPath"); +GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load("myModelPath"); +{% endhighlight %} + + +
+ +Note that the Python API does not yet support model save/load but will in the future. + +{% highlight python %} +from pyspark.mllib.tree import GradientBoostedTrees +from pyspark.mllib.util import MLUtils + +# Load and parse the data file. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GradientBoostedTrees model. +# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. +# (b) Use more iterations in practice. +model = GradientBoostedTrees.trainRegressor(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + +# Evaluate model on test instances and compute test error +predictions = model.predict(testData.map(lambda x: x.features)) +labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) +testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) +print('Test Mean Squared Error = ' + str(testMSE)) +print('Learned regression GBT model:') +print(model.toDebugString()) {% endhighlight %}
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 197bc77d506c6..80842b27effd8 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -240,11 +240,11 @@ following parameters in the constructor: * `withMean` False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception. -* `withStd` True by default. Scales the data to unit variance. +* `withStd` True by default. Scales the data to unit standard deviation. We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) method in `StandardScaler` which can take an input of `RDD[Vector]`, learn the summary statistics, and then -return a model which can transform the input dataset into unit variance and/or zero mean features +return a model which can transform the input dataset into unit standard deviation and/or zero mean features depending how we configure the `StandardScaler`. This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) @@ -257,7 +257,7 @@ for that feature. ### Example The example below demonstrates how to load a dataset in libsvm format, and standardize the features -so that the new features have unit variance and/or zero mean. +so that the new features have unit standard deviation and/or zero mean.
@@ -271,6 +271,8 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") val scaler1 = new StandardScaler().fit(data.map(x => x.features)) val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features)) +// scaler3 is an identical model to scaler2, and will produce identical transformations +val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean) // data1 will be unit variance. val data1 = data.map(x => (x.label, scaler1.transform(x.features))) @@ -294,6 +296,9 @@ features = data.map(lambda x: x.features) scaler1 = StandardScaler().fit(features) scaler2 = StandardScaler(withMean=True, withStd=True).fit(features) +# scaler3 is an identical model to scaler2, and will produce identical transformations +scaler3 = StandardScalerModel(scaler2.std, scaler2.mean) + # data1 will be unit variance. data1 = label.zip(scaler1.transform(features)) @@ -370,3 +375,105 @@ data2 = labels.zip(normalizer2.transform(features)) {% endhighlight %}
+ +## Feature selection +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. + +### ChiSqSelector +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. + +#### Model Fitting + +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the +following parameters in the constructor: + +* `numTopFeatures` number of top features that the selector will select (filter). + +We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in +`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then +return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. + +This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) +which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +an `RDD[Vector]` to produce a reduced `RDD[Vector]`. + +Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). + +#### Example + +The following example shows the basic use of ChiSqSelector. + +
+
+{% highlight scala %} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load some data in libsvm format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features +val discretizedData = data.map { lp => + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) +} +// Create ChiSqSelector that will select 50 features +val selector = new ChiSqSelector(50) +// Create ChiSqSelector model (selecting features) +val transformer = selector.fit(discretizedData) +// Filter the top 50 features from each feature vector +val filteredData = discretizedData.map { lp => + LabeledPoint(lp.label, transformer.transform(lp.features)) +} +{% endhighlight %} +
+ +
+{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.feature.ChiSqSelector; +import org.apache.spark.mllib.feature.ChiSqSelectorModel; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; + +SparkConf sparkConf = new SparkConf().setAppName("JavaChiSqSelector"); +JavaSparkContext sc = new JavaSparkContext(sparkConf); +JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), + "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); + +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features +JavaRDD discretizedData = points.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + final double[] discretizedFeatures = new double[lp.features().size()]; + for (int i = 0; i < lp.features().size(); ++i) { + discretizedFeatures[i] = lp.features().apply(i) / 16; + } + return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); + } + }); + +// Create ChiSqSelector that will select 50 features +ChiSqSelector selector = new ChiSqSelector(50); +// Create ChiSqSelector model (selecting features) +final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); +// Filter the top 50 features from each feature vector +JavaRDD filteredData = discretizedData.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + return new LabeledPoint(lp.label(), transformer.transform(lp.features())); + } + } +); + +sc.stop(); +{% endhighlight %} +
+
+ diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md new file mode 100644 index 0000000000000..9fd9be0dd01b1 --- /dev/null +++ b/docs/mllib-frequent-pattern-mining.md @@ -0,0 +1,98 @@ +--- +layout: global +title: Frequent Pattern Mining - MLlib +displayTitle: MLlib - Frequent Pattern Mining +--- + +Mining frequent items, itemsets, subsequences, or other substructures is usually among the +first steps to analyze a large-scale dataset, which has been an active research topic in +data mining for years. +We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) +for more information. +MLlib provides a parallel implementation of FP-growth, +a popular algorithm for mining frequent itemsets. + +## FP-growth + +The FP-growth algorithm is described in the paper +[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +where "FP" stands for frequent pattern. +Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. +Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, +the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets +explicitly, which are usually expensive to generate. +After the second step, the frequent itemsets can be extracted from the FP-tree. +In MLlib, we implemented a parallel version of FP-growth called PFP, +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +PFP distributes the work of growing FP-trees based on the suffixes of transactions, +and is hence more scalable than a single-machine implementation. +We refer users to the papers for more details. 
+ +MLlib's FP-growth implementation takes the following (hyper-)parameters: + +* `minSupport`: the minimum support for an itemset to be identified as frequent. + For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. +* `numPartitions`: the number of partitions used to distribute the work. + +**Examples** + +
+
+ +[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the +FP-growth algorithm. +It takes an `RDD` of transactions, where each transaction is an `Array` of items of a generic type. +Calling `FPGrowth.run` with transactions returns an +[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) +that stores the frequent itemsets with their frequencies. + +{% highlight scala %} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} + +val transactions: RDD[Array[String]] = ... + +val fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10) +val model = fpg.run(transactions) + +model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) +} +{% endhighlight %} + +
+ +
+ +[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +FP-growth algorithm. +It takes a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. +Calling `FPGrowth.run` with transactions returns an +[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +that stores the frequent itemsets with their frequencies. + +{% highlight java %} +import java.util.List; + +import com.google.common.base.Joiner; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; + +JavaRDD<List<String>> transactions = ... + +FPGrowth fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10); +FPGrowthModel<String> model = fpg.run(transactions); + +for (FPGrowth.FreqItemset<String> itemset: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + Joiner.on(",").join(itemset.javaItems()) + "], " + itemset.freq()); +} +{% endhighlight %} + +
+
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 39c64d06926bf..4c7a7d9115ca1 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: Machine Learning Library (MLlib) Programming Guide +title: MLlib +displayTitle: Machine Learning Library (MLlib) Guide +description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT --- MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, @@ -19,14 +21,21 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [naive Bayes](mllib-naive-bayes.html) * [decision trees](mllib-decision-tree.html) * [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees) + * [isotonic regression](mllib-isotonic-regression.html) * [Collaborative filtering](mllib-collaborative-filtering.html) * alternating least squares (ALS) * [Clustering](mllib-clustering.html) - * k-means + * [k-means](mllib-clustering.html#k-means) + * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) + * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic) + * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) + * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) * singular value decomposition (SVD) * principal component analysis (PCA) * [Feature extraction and transformation](mllib-feature-extraction.html) +* [Frequent pattern mining](mllib-frequent-pattern-mining.html) + * FP-growth * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) @@ -37,7 +46,7 @@ and the migration guide below will explain all changes between releases. 
# spark.ml: high-level APIs for ML pipelines -Spark 1.2 includes a new package called `spark.ml`, which aims to provide a uniform set of +Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. It is currently an alpha component, and we would like to hear back from the community about how it fits real-world use cases and how it could be improved. @@ -52,149 +61,53 @@ See the **[spark.ml programming guide](ml-guide.html)** for more information on # Dependencies -MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), -which depends on [netlib-java](https://github.com/fommil/netlib-java), -and [jblas](https://github.com/mikiobraun/jblas). -`netlib-java` and `jblas` depend on native Fortran routines. -You need to install the +MLlib uses the linear algebra package +[Breeze](http://www.scalanlp.org/), which depends on +[netlib-java](https://github.com/fommil/netlib-java) for optimised +numerical processing. If natives are not available at runtime, you +will see a warning message and a pure JVM implementation will be used +instead. + +To learn more about the benefits and background of system optimised +natives, you may wish to watch Sam Halliday's ScalaX talk on +[High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). + +Due to licensing issues with runtime proprietary binaries, we do not +include `netlib-java`'s native proxies by default. To configure +`netlib-java` / Breeze to use system optimised binaries, include +`com.github.fommil.netlib:all:1.1.2` (or build Spark with +`-Pnetlib-lgpl`) as a dependency of your project and read the +[netlib-java](https://github.com/fommil/netlib-java) documentation for +your platform's additional installation instructions. 
+ +MLlib also uses [jblas](https://github.com/mikiobraun/jblas) which +will require you to install the [gfortran runtime library](https://github.com/mikiobraun/jblas/wiki/Missing-Libraries) if it is not already present on your nodes. -MLlib will throw a linking error if it cannot detect these libraries automatically. -Due to license issues, we do not include `netlib-java`'s native libraries in MLlib's -dependency set under default settings. -If no native library is available at runtime, you will see a warning message. -To use native libraries from `netlib-java`, please build Spark with `-Pnetlib-lgpl` or -include `com.github.fommil.netlib:all:1.1.2` as a dependency of your project. -If you want to use optimized BLAS/LAPACK libraries such as -[OpenBLAS](http://www.openblas.net/), please link its shared libraries to -`/usr/lib/libblas.so.3` and `/usr/lib/liblapack.so.3`, respectively. -BLAS/LAPACK libraries on worker nodes should be built without multithreading. - -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. + +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) +version 1.4 or newer. --- # Migration Guide -## From 1.1 to 1.2 - -The only API changes in MLlib v1.2 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.2: - -1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number -of classes. In MLlib v1.1, this argument was called `numClasses` in Python and -`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. -This `numClasses` parameter is specified either via -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. - -2. 
*(Breaking change)* The API for -[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. -This should generally not affect user code, unless the user manually constructs decision trees -(instead of using the `trainClassifier` or `trainRegressor` methods). -The tree `Node` now includes more information, including the probability of the predicted label -(for classification). - -3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. - -Examples in the Spark distribution and examples in the -[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. - -## From 1.0 to 1.1 - -The only API changes in MLlib v1.1 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.1: - -1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match -the implementations of trees in -[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) -and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). -In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. -In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. -This depth is specified by the `maxDepth` parameter in -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. - -2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` -methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -rather than using the old parameter class `Strategy`. 
These new training methods explicitly -separate classification and regression, and they replace specialized parameter types with -simple `String` types. - -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the -[Decision Trees Guide](mllib-decision-tree.html#examples). - -## From 0.9 to 1.0 - -In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few -breaking changes. If your data is sparse, please store it in a sparse format instead of dense to -take advantage of sparsity in both storage and computation. Details are described below. - -
-
- -We used to represent a feature vector by `Array[Double]`, which is replaced by -[`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) in v1.0. Algorithms that used -to accept `RDD[Array[Double]]` now take -`RDD[Vector]`. [`LabeledPoint`](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) -is now a wrapper of `(Double, Vector)` instead of `(Double, Array[Double])`. Converting -`Array[Double]` to `Vector` is straightforward: - -{% highlight scala %} -import org.apache.spark.mllib.linalg.{Vector, Vectors} - -val array: Array[Double] = ... // a double array -val vector: Vector = Vectors.dense(array) // a dense vector -{% endhighlight %} - -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to create sparse vectors. - -*Note*: Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`. - -
- -
- -We used to represent a feature vector by `double[]`, which is replaced by -[`Vector`](api/java/index.html?org/apache/spark/mllib/linalg/Vector.html) in v1.0. Algorithms that used -to accept `RDD` now take -`RDD`. [`LabeledPoint`](api/java/index.html?org/apache/spark/mllib/regression/LabeledPoint.html) -is now a wrapper of `(double, Vector)` instead of `(double, double[])`. Converting `double[]` to -`Vector` is straightforward: - -{% highlight java %} -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; - -double[] array = ... // a double array -Vector vector = Vectors.dense(array); // a dense vector -{% endhighlight %} - -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to -create sparse vectors. +For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). -
- -
+## From 1.2 to 1.3 -We used to represent a labeled feature vector in a NumPy array, where the first entry corresponds to -the label and the rest are features. This representation is replaced by class -[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html), which takes both -dense and sparse feature vectors. +In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. -{% highlight python %} -from pyspark.mllib.linalg import SparseVector -from pyspark.mllib.regression import LabeledPoint +* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. +* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. +* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: + * The constructor taking arguments was removed in favor of a builder pattern using the default constructor plus parameter setter methods. + * Variable `model` is no longer public. +* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: + * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) + * In `Strategy`, the `checkpointDir` parameter has been removed. 
Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. +* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. -# Create a labeled point with a positive label and a dense feature vector. -pos = LabeledPoint(1.0, [1.0, 0.0, 3.0]) +## Previous Spark Versions -# Create a labeled point with a negative label and a sparse feature vector. -neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0])) -{% endhighlight %} -
-
+Earlier migration guides are archived [on this page](mllib-migration-guides.html). diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md new file mode 100644 index 0000000000000..12fb29d426741 --- /dev/null +++ b/docs/mllib-isotonic-regression.md @@ -0,0 +1,155 @@ +--- +layout: global +title: Isotonic regression - MLlib +displayTitle: MLlib - Regression +--- + +## Isotonic regression +[Isotonic regression](http://en.wikipedia.org/wiki/Isotonic_regression) +belongs to the family of regression algorithms. Formally isotonic regression is a problem where +given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses +and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted +finding a function that minimises + +`\begin{equation} + f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 +\end{equation}` + +with respect to complete order subject to +`$x_1\le x_2\le ...\le x_n$` where `$w_i$` are positive weights. +The resulting function is called isotonic regression and it is unique. +It can be viewed as a least squares problem under order restriction. +Essentially isotonic regression is a +[monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) +best fitting the original data points. + +MLlib supports a +[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +which uses an approach to +[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +The training input is an RDD of tuples of three double values that represent +label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +optional parameter called $isotonic$ defaulting to true. +This argument specifies if the isotonic regression is +isotonic (monotonically increasing) or antitonic (monotonically decreasing). + +Training returns an IsotonicRegressionModel that can be used to predict +labels for both known and unknown features. 
The result of isotonic regression +is treated as piecewise linear function. The rules for prediction therefore are: + +* If the prediction input exactly matches a training feature + then associated prediction is returned. In case there are multiple predictions with the same + feature then one of them is returned. Which one is undefined + (same as java.util.Arrays.binarySearch). +* If the prediction input is lower or higher than all training features + then prediction with lowest or highest feature is returned respectively. + In case there are multiple predictions with the same feature + then the lowest or highest is returned respectively. +* If the prediction input falls between two training features then prediction is treated + as piecewise linear function and interpolated value is calculated from the + predictions of the two closest features. In case there are multiple values + with the same feature then the same rules as in previous point are used. + +### Examples + +
+
+Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight scala %} +import org.apache.spark.mllib.regression.IsotonicRegression + +val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + (parts(0), parts(1), 1.0) +} + +// Split data into training (60%) and test (40%) sets. +val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) +val training = splits(0) +val test = splits(1) + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +val model = new IsotonicRegression().setIsotonic(true).run(training) + +// Create tuples of predicted and real labels. +val predictionAndLabel = test.map { point => + val predictedLabel = model.predict(point._2) + (predictedLabel, point._1) +} + +// Calculate mean squared error between predicted and real labels. +val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() +println("Mean Squared Error = " + meanSquaredError) +{% endhighlight %} +
+ +
+Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.IsotonicRegressionModel; +import scala.Tuple2; +import scala.Tuple3; + +JavaRDD data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt"); + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +JavaRDD> parsedData = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(","); + return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); + } + } +); + +// Split data into training (60%) and test (40%) sets. +JavaRDD>[] splits = parsedData.randomSplit(new double[] {0.6, 0.4}, 11L); +JavaRDD> training = splits[0]; +JavaRDD> test = splits[1]; + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); + +// Create tuples of predicted and real labels. +JavaPairRDD predictionAndLabel = test.mapToPair( + new PairFunction, Double, Double>() { + @Override public Tuple2 call(Tuple3 point) { + Double predictedLabel = model.predict(point._2()); + return new Tuple2(predictedLabel, point._1()); + } + } +); + +// Calculate mean squared error between predicted and real labels. 
+Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( + new Function, Object>() { + @Override public Object call(Tuple2 pl) { + return Math.pow(pl._1() - pl._2(), 2); + } + } +).rdd()).mean(); + +System.out.println("Mean Squared Error = " + meanSquaredError); +{% endhighlight %} +
+
\ No newline at end of file diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 44b7f67c57734..d9fc63b37d116 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -190,7 +190,7 @@ error. {% highlight scala %} import org.apache.spark.SparkContext -import org.apache.spark.mllib.classification.SVMWithSGD +import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors @@ -222,6 +222,9 @@ val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() println("Area under ROC = " + auROC) + +model.save("myModelPath") +val sameModel = SVMModel.load("myModelPath") {% endhighlight %} The `SVMWithSGD.train()` method by default performs L2 regularization with the @@ -304,6 +307,9 @@ public class SVMClassifier { double auROC = metrics.areaUnderROC(); System.out.println("Area under ROC = " + auROC); + + model.save("myModelPath"); + SVMModel sameModel = SVMModel.load("myModelPath"); } } {% endhighlight %} @@ -338,6 +344,8 @@ a dependency. The following example shows how to load a sample dataset, build Logistic Regression model, and make predictions with the resulting model to compute the training error. +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.classification import LogisticRegressionWithSGD from pyspark.mllib.regression import LabeledPoint @@ -391,8 +399,9 @@ values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). 
{% highlight scala %} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -413,6 +422,9 @@ val valuesAndPreds = parsedData.map { point => } val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() println("training Mean Squared Error = " + MSE) + +model.save("myModelPath") +val sameModel = LinearRegressionModel.load("myModelPath") {% endhighlight %} [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) @@ -483,6 +495,9 @@ public class LinearRegression { } ).rdd()).mean(); System.out.println("training Mean Squared Error = " + MSE); + + model.save("myModelPath"); + LinearRegressionModel sameModel = LinearRegressionModel.load("myModelPath"); } } {% endhighlight %} @@ -494,6 +509,8 @@ The example then uses LinearRegressionWithSGD to build a simple linear model to values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD from numpy import array diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md new file mode 100644 index 0000000000000..4de2d9491ac2b --- /dev/null +++ b/docs/mllib-migration-guides.md @@ -0,0 +1,67 @@ +--- +layout: global +title: Old Migration Guides - MLlib +displayTitle: MLlib - Old Migration Guides +description: MLlib migration guides from before Spark SPARK_VERSION_SHORT +--- + +The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). 
+ +## From 1.1 to 1.2 + +The only API changes in MLlib v1.2 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.2: + +1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number +of classes. In MLlib v1.1, this argument was called `numClasses` in Python and +`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. +This `numClasses` parameter is specified either via +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Breaking change)* The API for +[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. +This should generally not affect user code, unless the user manually constructs decision trees +(instead of using the `trainClassifier` or `trainRegressor` methods). +The tree `Node` now includes more information, including the probability of the predicted label +(for classification). + +3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. + +Examples in the Spark distribution and examples in the +[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. + +## From 1.0 to 1.1 + +The only API changes in MLlib v1.1 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.1: + +1. 
*(Breaking change)* The meaning of tree depth has been changed by 1 in order to match +the implementations of trees in +[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) +and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). +In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. +In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. +This depth is specified by the `maxDepth` parameter in +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` +methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +rather than using the old parameter class `Strategy`. These new training methods explicitly +separate classification and regression, and they replace specialized parameter types with +simple `String` types. + +Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +[Decision Trees Guide](mllib-decision-tree.html#examples). + +## From 0.9 to 1.0 + +In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few +breaking changes. If your data is sparse, please store it in a sparse format instead of dense to +take advantage of sparsity in both storage and computation. Details are described below. + diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index a71b93fe0daf4..7cbc5825c9127 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -40,7 +40,7 @@ smoothing parameter `lambda` as input, an optional model type parameter (default can be used for evaluation and prediction. 
{% highlight scala %} -import org.apache.spark.mllib.classification.NaiveBayes +import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -58,6 +58,9 @@ val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() + +model.save("myModelPath") +val sameModel = NaiveBayesModel.load("myModelPath") {% endhighlight %} @@ -96,6 +99,9 @@ double accuracy = predictionAndLabel.filter(new Function, return pl._1().equals(pl._2()); } }).count() / (double) test.count(); + +model.save("myModelPath"); +NaiveBayesModel sameModel = NaiveBayesModel.load("myModelPath"); {% endhighlight %} @@ -108,6 +114,8 @@ smoothing parameter `lambda` as input, and output a [NaiveBayesModel](api/python/pyspark.mllib.classification.NaiveBayesModel-class.html), which can be used for evaluation and prediction. +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.regression import LabeledPoint diff --git a/docs/monitoring.md b/docs/monitoring.md index f32cdef240d31..009a344dff4bb 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -1,6 +1,7 @@ --- layout: global title: Monitoring and Instrumentation +description: Monitoring, metrics, and instrumentation guide for Spark SPARK_VERSION_SHORT --- There are several ways to monitor Spark applications: web UIs, metrics, and external instrumentation. @@ -175,6 +176,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the * `JmxSink`: Registers metrics for viewing in a JMX console. * `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data. * `GraphiteSink`: Sends metrics to a Graphite node. 
+* `Slf4jSink`: Sends metrics to slf4j as log entries. Spark also supports a Ganglia sink which is not included in the default build due to licensing restrictions: diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 5e0d5c15d7069..7b0701828878e 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1,6 +1,7 @@ --- layout: global title: Spark Programming Guide +description: Spark SPARK_VERSION_SHORT programming guide in Java, Scala and Python --- * This will become a table of contents (this text will be scraped). @@ -172,8 +173,11 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. -For example, to run `bin/spark-shell` on exactly four cores, use: +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) +can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly +four cores, use: {% highlight bash %} $ ./bin/spark-shell --master local[4] @@ -185,6 +189,12 @@ Or, to also add `code.jar` to its classpath, use: $ ./bin/spark-shell --master local[4] --jars code.jar {% endhighlight %} +To include a dependency using maven coordinates: + +{% highlight bash %} +$ ./bin/spark-shell --master local[4] --packages "org.example:example:0.1" +{% endhighlight %} + For a complete list of options, run `spark-shell --help`. Behind the scenes, `spark-shell` invokes the more general [`spark-submit` script](submitting-applications.html). 
@@ -195,7 +205,11 @@ For a complete list of options, run `spark-shell --help`. Behind the scenes, In the PySpark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add Python .zip, .egg or .py files -to the runtime path by passing a comma-separated list to `--py-files`. +to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: {% highlight bash %} @@ -321,7 +335,7 @@ Apart from text files, Spark's Scala API also supports several other data format * For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). In addition, Spark allows you to specify native types for a few common Writables; for example, `sequenceFile[Int, String]` will automatically read IntWritables and Texts. 
-* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). +* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). * `RDD.saveAsObjectFile` and `SparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD. @@ -353,7 +367,7 @@ Apart from text files, Spark's Java API also supports several other data formats * For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). -* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. 
You can also use `JavaSparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). +* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `JavaSparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). * `JavaRDD.saveAsObjectFile` and `JavaSparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD. @@ -886,7 +900,7 @@ for details. groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or combineByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. @@ -913,7 +927,7 @@ for details. cogroup(otherDataset, [numTasks]) - When called on datasets of type (K, V) and (K, W), returns a dataset of (K, Iterable<V>, Iterable<W>) tuples. This operation is also called groupWith. + When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. cartesian(otherDataset) @@ -974,7 +988,7 @@ for details. take(n) - Return an array with the first n elements of the dataset. Note that this is currently not executed in parallel. Instead, the driver program computes all the elements. + Return an array with the first n elements of the dataset. takeSample(withReplacement, num, [seed]) @@ -1316,7 +1330,35 @@ For accumulator updates performed inside actions only, Spark guarantees t will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware of that each task's update may be applied more than once if tasks or job stages are re-executed. +Accumulators do not change the lazy evaluation model of Spark. If they are being updated within an operation on an RDD, their value is only updated once that RDD is computed as part of an action. Consequently, accumulator updates are not guaranteed to be executed when made within a lazy transformation like `map()`. The below code fragment demonstrates this property: + +
+ +
+{% highlight scala %} +val acc = sc.accumulator(0) +data.map { x => acc += x; f(x) } +// Here, acc is still 0 because no actions have caused the `map` to be computed. +{% endhighlight %} +
+
+{% highlight java %} +Accumulator accum = sc.accumulator(0); +data.map(x -> { accum.add(x); return f(x); }); +// Here, accum is still 0 because no actions have caused the `map` to be computed. +{% endhighlight %} +
+ +
+{% highlight python %} +accum = sc.accumulator(0) +data.map(lambda x => accum.add(x); f(x)) +# Here, accum is still 0 because no actions have caused the `map` to be computed. +{% endhighlight %} +
+ +
# Deploying to a Cluster diff --git a/docs/quick-start.md b/docs/quick-start.md index bf643bb70e153..81143da865cf0 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -1,6 +1,7 @@ --- layout: global title: Quick Start +description: Quick start tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 78358499fd01f..db1173a06b0b1 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -197,7 +197,11 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.coarse false - Set the run mode for Spark on Mesos. For more information about the run mode, refer to #Mesos Run Mode section above. + If set to "true", runs over Mesos clusters in + "coarse-grained" sharing mode, + where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per + Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use + for the whole duration of the Spark job. @@ -211,19 +215,23 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.executor.home - SPARK_HOME + driver side SPARK_HOME - The location where the mesos executor will look for Spark binaries to execute, and uses the SPARK_HOME setting on default. - This variable is only used when no spark.executor.uri is provided, and assumes Spark is installed on the specified location - on each slave. + Set the directory in which Spark is installed on the executors in Mesos. By default, the + executors will simply use the driver's Spark home directory, which may not be visible to + them. Note that this is only relevant if a Spark binary package is not specified through + spark.executor.uri. 
spark.mesos.executor.memoryOverhead - 384 + executor memory * 0.07, with minimum of 384 - The amount of memory that Mesos executor will request for the task to account for the overhead of running the executor itself. - The final total amount of memory allocated is the maximum value between executor memory plus memoryOverhead, and overhead fraction (1.07) plus the executor memory. + This value is an additive for spark.executor.memory, specified in MiB, + which is used to calculate the total Mesos task memory. A value of 384 + implies a 384MiB overhead. Additionally, there is a hard-coded 7% minimum + overhead. The final overhead will be the larger of either + `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4f273098c5db3..2b93eef6c26ed 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -29,6 +29,23 @@ Most of the configs are the same for Spark on YARN as for other deployment modes In cluster mode, use spark.driver.memory instead. + + spark.driver.cores + 1 + + Number of cores used by the driver in YARN cluster mode. + Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM. + In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead. + + + + spark.yarn.am.cores + 1 + + Number of cores to use for the YARN Application Master in client mode. + In cluster mode, use spark.driver.cores instead. + + spark.yarn.am.waitTime 100000 @@ -87,6 +104,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes Comma-separated list of files to be placed in the working directory of each executor. + + spark.executor.instances + 2 + + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. 
+ + spark.yarn.executor.memoryOverhead executorMemory * 0.07, with minimum of 384 diff --git a/docs/security.md b/docs/security.md index 1e206a139fb72..c034ba12ff1fc 100644 --- a/docs/security.md +++ b/docs/security.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark Security +displayTitle: Spark Security +title: Security --- Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: @@ -20,6 +21,30 @@ Spark allows for a set of administrators to be specified in the acls who always If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. +## Encryption + +Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. However SSL is not supported yet for WebUI and block transfer service. + +Connection encryption (SSL) configuration is organized hierarchically. 
The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). + +SSL must be configured on each node and configured for each component involved in communication using the particular protocol. + +### YARN mode +The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. + +### Standalone mode +The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. + +### Preparing the key-stores +Key-stores can be generated by `keytool` program. The reference documentation for this tool is +[here](https://docs.oracle.com/javase/7/docs/technotes/tools/solaris/keytool.html). 
The most basic +steps to configure the key-stores and the trust-store for the standalone deployment mode are as +follows: +* Generate a key pair for each node +* Export the public key of the key pair to a file on each node +* Import all exported public keys into a single trust-store +* Distribute the trust-store over the nodes + ## Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 729045b81a8c0..0146a4ed1b745 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark SQL Programming Guide +displayTitle: Spark SQL Programming Guide +title: Spark SQL --- * This will become a table of contents (this text will be scraped). @@ -13,10 +14,10 @@ title: Spark SQL Programming Guide Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of +[DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame). DataFrames are composed of [Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) +a schema that describes the data types of each column in the row. A DataFrame is similar to a table +in a traditional relational database. A DataFrame can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`. @@ -26,10 +27,10 @@ All of the examples on this page use sample data included in the Spark distribut
Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using Spark. At the core of this component is a new type of RDD, -[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed of +[DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame). DataFrames are composed of [Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects, along with -a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table -in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) +a schema that describes the data types of each column in the row. A DataFrame is similar to a table +in a traditional relational database. A DataFrame can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
@@ -37,10 +38,10 @@ file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive]( Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed of +[DataFrame](api/python/pyspark.sql.html#pyspark.sql.DataFrame). DataFrames are composed of [Row](api/python/pyspark.sql.Row-class.html) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) +a schema that describes the data types of each column in the row. A DataFrame is similar to a table +in a traditional relational database. A DataFrame can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell. @@ -64,8 +65,8 @@ descendants. To create a basic SQLContext, all you need is a SparkContext. val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ {% endhighlight %} In addition to the basic SQLContext, you can also create a HiveContext, which provides a @@ -83,12 +84,12 @@ feature parity with a HiveContext.
The entry point into all relational functionality in Spark is the -[JavaSQLContext](api/scala/index.html#org.apache.spark.sql.api.java.JavaSQLContext) class, or one -of its descendants. To create a basic JavaSQLContext, all you need is a JavaSparkContext. +[SQLContext](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one +of its descendants. To create a basic SQLContext, all you need is a JavaSparkContext. {% highlight java %} JavaSparkContext sc = ...; // An existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); {% endhighlight %} In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict @@ -137,39 +138,39 @@ default is "hiveql", though "sql" is also available. Since the HiveQL parser is # Data Sources -Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. -A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section -describes the various methods for loading data into a SchemaRDD. +Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. +A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. +Registering a DataFrame as a table allows you to run SQL queries over its data. This section +describes the various methods for loading data into a DataFrame. ## RDDs -Spark SQL supports two different methods for converting existing RDDs into SchemaRDDs. The first +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first method uses reflection to infer the schema of an RDD that contains specific types of objects.
This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. -The second method for creating SchemaRDDs is through a programmatic interface that allows you to +The second method for creating DataFrames is through a programmatic interface that allows you to construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows -you to construct SchemaRDDs when the columns and their types are not known until runtime. +you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection
-The Scala interaface for Spark SQL supports automatically converting an RDD containing case classes -to a SchemaRDD. The case class +The Scala interface for Spark SQL supports automatically converting an RDD containing case classes +to a DataFrame. The case class defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex -types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be +types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ // Define the schema using a case class. // Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, @@ -183,7 +184,7 @@ people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -193,7 +194,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. @@ -224,12 +225,12 @@ public static class Person implements Serializable { {% endhighlight %} -A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object +A schema can be applied to an existing RDD by calling `createDataFrame` and providing the Class object for the JavaBean. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( @@ -246,13 +247,13 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m }); // Apply a schema to an RDD of JavaBeans and register it as a table. -JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); +DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. 
+// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List teenagerNames = teenagers.map(new Function() { public String call(Row row) { @@ -266,7 +267,7 @@ List teenagerNames = teenagers.map(new Function() {
-Spark SQL can convert an RDD of Row objects to a SchemaRDD, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we @@ -283,11 +284,11 @@ lines = sc.textFile("examples/src/main/resources/people.txt") parts = lines.map(lambda l: l.split(",")) people = parts.map(lambda p: Row(name=p[0], age=int(p[1]))) -# Infer the schema, and register the SchemaRDD as a table. +# Infer the schema, and register the DataFrame as a table. schemaPeople = sqlContext.inferSchema(people) schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") # The results of SQL queries are RDDs and support all the normal RDD operations. @@ -309,12 +310,12 @@ for teenName in teenNames.collect(): When case classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided by `SQLContext`. 
For example: @@ -340,15 +341,15 @@ val schema = val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim)) // Apply the schema to the RDD. -val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema) +val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema) -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people") +// Register the DataFrame as a table. +peopleDataFrame.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val results = sqlContext.sql("SELECT name FROM people") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. results.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -361,13 +362,13 @@ results.map(t => "Name: " + t(0)).collect().foreach(println) When JavaBean classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided -by `JavaSQLContext`. +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided +by `SQLContext`. For example: {% highlight java %} @@ -380,7 +381,7 @@ import org.apache.spark.sql.api.java.StructField import org.apache.spark.sql.api.java.Row // sc is an existing JavaSparkContext.
-JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt"); @@ -405,15 +406,15 @@ JavaRDD rowRDD = people.map( }); // Apply the schema to the RDD. -JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema); +DataFrame peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema); -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people"); +// Register the DataFrame as a table. +peopleDataFrame.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD results = sqlContext.sql("SELECT name FROM people"); +DataFrame results = sqlContext.sql("SELECT name FROM people"); -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List names = results.map(new Function() { public String call(Row row) { @@ -430,12 +431,12 @@ List names = results.map(new Function() { When a dictionary of kwargs cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of tuples or lists from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of tuples or lists in the RDD created in the step 1. -3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`. +3. Apply the schema to the RDD via `createDataFrame` method provided by `SQLContext`. 
For example: {% highlight python %} @@ -457,12 +458,12 @@ fields = [StructField(field_name, StringType(), True) for field_name in schemaSt schema = StructType(fields) # Apply the schema to the RDD. -schemaPeople = sqlContext.applySchema(people, schema) +schemaPeople = sqlContext.createDataFrame(people, schema) -# Register the SchemaRDD as a table. +# Register the DataFrame as a table. schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. results = sqlContext.sql("SELECT name FROM people") # The results of SQL queries are RDDs and support all the normal RDD operations. @@ -492,16 +493,16 @@ Using the data from the above example: {% highlight scala %} // sqlContext from the previous example is used in this example. -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// This is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. -// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet. +// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. people.saveAsParquetFile("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a Parquet file is also a SchemaRDD. +// The result of loading a Parquet file is also a DataFrame. val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. @@ -517,18 +518,18 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% highlight java %} // sqlContext from the previous example is used in this example. -JavaSchemaRDD schemaPeople = ... 
// The JavaSchemaRDD from the previous example. +DataFrame schemaPeople = ... // The DataFrame from the previous example. -// JavaSchemaRDDs can be saved as Parquet files, maintaining the schema information. +// DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a parquet file is also a JavaSchemaRDD. -JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); +// The result of loading a parquet file is also a DataFrame. +DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); +DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); @@ -543,13 +544,13 @@ List teenagerNames = teenagers.map(new Function() { {% highlight python %} # sqlContext from the previous example is used in this example. -schemaPeople # The SchemaRDD from the previous example. +schemaPeople # The DataFrame from the previous example. -# SchemaRDDs can be saved as Parquet files, maintaining the schema information. +# DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -# The result of loading a parquet file is also a SchemaRDD. +# The result of loading a parquet file is also a DataFrame. 
parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. @@ -580,6 +581,15 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. + + spark.sql.parquet.int96AsTimestamp + true + + Some Parquet-producing systems, in particular Impala, store Timestamp as INT96. Spark would also + store Timestamp as INT96 because we need to avoid losing precision in the nanoseconds field. This + flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. + + spark.sql.parquet.cacheMetadata true @@ -619,7 +629,7 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. This conversion can be done using one of two methods in a SQLContext: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. @@ -636,7 +646,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a SchemaRDD from the file(s) pointed to by path +// Create a DataFrame from the file(s) pointed to by path val people = sqlContext.jsonFile(path) // The inferred schema can be visualized using the printSchema() method. @@ -645,13 +655,13 @@ people.printSchema() // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this SchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// Alternatively, a SchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) @@ -661,8 +671,8 @@ val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a JavaSchemaRDD. -This conversion can be done using one of two methods in a JavaSQLContext : +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using one of two methods in a SQLContext : * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. * `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. @@ -673,13 +683,13 @@ a regular multi-line JSON file will most often fail. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; -// Create a JavaSchemaRDD from the file(s) pointed to by path -JavaSchemaRDD people = sqlContext.jsonFile(path); +// Create a DataFrame from the file(s) pointed to by path +DataFrame people = sqlContext.jsonFile(path); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -687,23 +697,23 @@ people.printSchema(); // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this JavaSchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. 
-JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); -// Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -JavaSchemaRDD anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); {% endhighlight %}
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. This conversion can be done using one of two methods in a SQLContext: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. @@ -721,7 +731,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. path = "examples/src/main/resources/people.json" -# Create a SchemaRDD from the file(s) pointed to by path +# Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # The inferred schema can be visualized using the printSchema() method. @@ -730,13 +740,13 @@ people.printSchema() # |-- age: integer (nullable = true) # |-- name: string (nullable = true) -# Register this SchemaRDD as a table. +# Register this DataFrame as a table. people.registerTempTable("people") # SQL statements can be run by using the sql methods provided by sqlContext. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -# Alternatively, a SchemaRDD can be created for a JSON dataset represented by +# Alternatively, a DataFrame can be created for a JSON dataset represented by # an RDD[String] storing one JSON object per string. anotherPeopleRDD = sc.parallelize([ '{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}']) @@ -782,14 +792,14 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println)
-When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be +the `sql` method a `HiveContext` also provides an `hql` method, which allows queries to be expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. -JavaHiveContext sqlContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); +HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc); sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); @@ -831,7 +841,7 @@ turning on some experimental options. ## Caching Data In Memory -Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `schemaRDD.cache()`. +Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `dataFrame.cache()`. Then Spark SQL will scan only required columns and will automatically tune compression to minimize memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory. @@ -1098,7 +1108,7 @@ in Hive deployments. have the same input format. * Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple.
-* `UNION` type and `DATE` type +* `UNION` type * Unique join * Single query multi insert * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at @@ -1151,7 +1161,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are evaluated by the SQL execution engine. A full list of the functions supported can be found in the -[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). +[ScalaDoc](api/scala/index.html#org.apache.spark.sql.DataFrame). @@ -1333,9 +1343,9 @@ import org.apache.spark.sql._
All data types of Spark SQL are located in the package of -`org.apache.spark.sql.api.java`. To access or create a data type, +`org.apache.spark.sql.types`. To access or create a data type, please use factory methods provided in -`org.apache.spark.sql.api.java.DataType`. +`org.apache.spark.sql.types.DataTypes`. @@ -1346,109 +1356,110 @@ please use factory methods provided in @@ -1458,7 +1469,7 @@ please use factory methods provided in
ByteType byte or Byte - DataType.ByteType + DataTypes.ByteType
ShortType short or Short - DataType.ShortType + DataTypes.ShortType
IntegerType int or Integer - DataType.IntegerType + DataTypes.IntegerType
LongType long or Long - DataType.LongType + DataTypes.LongType
FloatType float or Float - DataType.FloatType + DataTypes.FloatType
DoubleType double or Double - DataType.DoubleType + DataTypes.DoubleType
DecimalType java.math.BigDecimal - DataType.DecimalType + DataTypes.createDecimalType()
+ DataTypes.createDecimalType(precision, scale).
StringType String - DataType.StringType + DataTypes.StringType
BinaryType byte[] - DataType.BinaryType + DataTypes.BinaryType
BooleanType boolean or Boolean - DataType.BooleanType + DataTypes.BooleanType
TimestampType java.sql.Timestamp - DataType.TimestampType + DataTypes.TimestampType
DateType java.sql.Date - DataType.DateType + DataTypes.DateType
ArrayType java.util.List - DataType.createArrayType(elementType)
+ DataTypes.createArrayType(elementType)
Note: The value of containsNull will be true.
- DataType.createArrayType(elementType, containsNull). + DataTypes.createArrayType(elementType, containsNull).
MapType java.util.Map - DataType.createMapType(keyType, valueType)
+ DataTypes.createMapType(keyType, valueType)
Note: The value of valueContainsNull will be true.
- DataType.createMapType(keyType, valueType, valueContainsNull)
+ DataTypes.createMapType(keyType, valueType, valueContainsNull)
StructType org.apache.spark.sql.Row - DataType.createStructType(fields)
+ DataTypes.createStructType(fields)
Note: fields is a List or an array of StructFields. Also, two fields with the same name are not allowed.
The value type in Java of the data type of this field (For example, int for a StructField with the data type IntegerType) - DataType.createStructField(name, dataType, nullable) + DataTypes.createStructField(name, dataType, nullable)
diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index ac01dd3d8019a..40e17246fea83 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -64,7 +64,7 @@ configuring Flume agents. 3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). -## Approach 2 (Experimental): Pull-based Approach using a Custom Sink +## Approach 2: Pull-based Approach using a Custom Sink Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following. - Flume pushes data into the sink, and the data stays buffered. diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 0e38fe2144e9f..77c0abbbacbd0 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -29,7 +29,7 @@ title: Spark Streaming + Kafka Integration Guide streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index e37a2bb37b9a4..815c98713b738 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: Spark Streaming Programming Guide +displayTitle: Spark Streaming Programming Guide +title: Spark Streaming +description: Spark Streaming programming guide and tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). @@ -876,6 +878,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi val runningCounts = pairs.updateStateByKey[Int](updateFunction _) {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Scala code, take a look at the example +[StatefulNetworkWordCount.scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala). +
@@ -897,6 +905,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction); {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Java code, take a look at the example +[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java). +
@@ -914,14 +928,14 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi runningCounts = pairs.updateStateByKey(updateFunction) {% endhighlight %} -
-
- The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete -Scala code, take a look at the example +Python code, take a look at the example [stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). +
+
+ Note that using `updateStateByKey` requires the checkpoint directory to be configured, which is discussed in detail in the [checkpointing](#checkpointing) section. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 3bd1deaccfafe..57b074778f2b0 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -58,8 +58,8 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Note that `cluster` mode is currently not supported for standalone -clusters, Mesos clusters, or Python applications. +the drivers and the executors. Note that `cluster` mode is currently not supported for +Mesos clusters or Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. @@ -174,6 +174,11 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. + For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries to executors. 
diff --git a/docs/tuning.md b/docs/tuning.md index efaac9d3d405f..cbd227868b248 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -1,6 +1,8 @@ --- layout: global -title: Tuning Spark +displayTitle: Tuning Spark +title: Tuning +description: Tuning and performance optimization guide for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 3abd3f396f605..26e7d22655694 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -20,6 +20,6 @@ # Preserve the user's CWD so that relative paths are passed correctly to #+ the underlying Python script. -SPARK_EC2_DIR="$(dirname $0)" +SPARK_EC2_DIR="$(dirname "$0")" python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index abab209a05ba0..c59ab565c6862 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -24,14 +24,17 @@ import hashlib import logging import os +import os.path import pipes import random import shutil import string +from stat import S_IRUSR import subprocess import sys import tarfile import tempfile +import textwrap import time import urllib2 import warnings @@ -39,6 +42,9 @@ from optparse import OptionParser from sys import stderr +SPARK_EC2_VERSION = "1.2.1" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) + VALID_SPARK_VERSIONS = set([ "0.7.3", "0.8.0", @@ -52,15 +58,15 @@ "1.1.0", "1.1.1", "1.2.0", + "1.2.1", ]) -DEFAULT_SPARK_VERSION = "1.2.0" +DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" -SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) -MESOS_SPARK_EC2_BRANCH = "branch-1.3" -# A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) +# Default location to get the spark-ec2 scripts (and ami-list) from +DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2" +DEFAULT_SPARK_EC2_BRANCH = 
"branch-1.3" def setup_boto(): @@ -103,12 +109,11 @@ class UsageError(Exception): # Configure and parse our command-line arguments def parse_args(): parser = OptionParser( - usage="spark-ec2 [options] " - + "\n\n can be: launch, destroy, login, stop, start, get-master, reboot-slaves", - add_help_option=False) - parser.add_option( - "-h", "--help", action="help", - help="Show this help message and exit") + prog="spark-ec2", + version="%prog {v}".format(v=SPARK_EC2_VERSION), + usage="%prog [options] \n\n" + + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves") + parser.add_option( "-s", "--slaves", type="int", default=1, help="Number of slaves to launch (default: %default)") @@ -130,13 +135,15 @@ def parse_args(): help="Master instance type (leave empty for same as instance-type)") parser.add_option( "-r", "--region", default="us-east-1", - help="EC2 region zone to launch instances in") + help="EC2 region used to launch instances in, or to find them in") parser.add_option( "-z", "--zone", default="", help="Availability zone to launch instances in, or 'all' to spread " + "slaves across multiple (an additional $0.01/Gb for bandwidth" + "between zones applies) (default: a single zone chosen at random)") - parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") + parser.add_option( + "-a", "--ami", + help="Amazon Machine Image ID to use") parser.add_option( "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") @@ -144,6 +151,14 @@ def parse_args(): "--spark-git-repo", default=DEFAULT_SPARK_GITHUB_REPO, help="Github repo from which to checkout supplied commit hash (default: %default)") + parser.add_option( + "--spark-ec2-git-repo", + default=DEFAULT_SPARK_EC2_GITHUB_REPO, + help="Github repo from which to checkout spark-ec2 (default: %default)") + parser.add_option( + "--spark-ec2-git-branch", + default=DEFAULT_SPARK_EC2_BRANCH, + help="Github repo 
branch of spark-ec2 to use (default: %default)") parser.add_option( "--hadoop-major-version", default="1", help="Major version of Hadoop (default: %default)") @@ -168,10 +183,11 @@ def parse_args(): "Only possible on EBS-backed AMIs. " + "EBS volumes are only attached if --ebs-vol-size > 0." + "Only support up to 8 EBS volumes.") - parser.add_option("--placement-group", type="string", default=None, - help="Which placement group to try and launch " + - "instances into. Assumes placement group is already " + - "created.") + parser.add_option( + "--placement-group", type="string", default=None, + help="Which placement group to try and launch " + + "instances into. Assumes placement group is already " + + "created.") parser.add_option( "--swap", metavar="SWAP", type="int", default=1024, help="Swap space to set up per node, in MB (default: %default)") @@ -215,9 +231,11 @@ def parse_args(): "--copy-aws-credentials", action="store_true", default=False, help="Add AWS credentials to hadoop configuration to allow Spark to access S3") parser.add_option( - "--subnet-id", default=None, help="VPC subnet to launch instances in") + "--subnet-id", default=None, + help="VPC subnet to launch instances in") parser.add_option( - "--vpc-id", default=None, help="VPC to launch instances in") + "--vpc-id", default=None, + help="VPC to launch instances in") (opts, args) = parser.parse_args() if len(args) != 2: @@ -279,58 +297,65 @@ def is_active(instance): return (instance.state in ['pending', 'running', 'stopping', 'stopped']) -# Attempt to resolve an appropriate AMI given the architecture and region of the request. # Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ # Last Updated: 2014-06-20 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. 
+EC2_INSTANCE_TYPES = { + "c1.medium": "pvm", + "c1.xlarge": "pvm", + "c3.2xlarge": "pvm", + "c3.4xlarge": "pvm", + "c3.8xlarge": "pvm", + "c3.large": "pvm", + "c3.xlarge": "pvm", + "cc1.4xlarge": "hvm", + "cc2.8xlarge": "hvm", + "cg1.4xlarge": "hvm", + "cr1.8xlarge": "hvm", + "hi1.4xlarge": "pvm", + "hs1.8xlarge": "pvm", + "i2.2xlarge": "hvm", + "i2.4xlarge": "hvm", + "i2.8xlarge": "hvm", + "i2.xlarge": "hvm", + "m1.large": "pvm", + "m1.medium": "pvm", + "m1.small": "pvm", + "m1.xlarge": "pvm", + "m2.2xlarge": "pvm", + "m2.4xlarge": "pvm", + "m2.xlarge": "pvm", + "m3.2xlarge": "hvm", + "m3.large": "hvm", + "m3.medium": "hvm", + "m3.xlarge": "hvm", + "r3.2xlarge": "hvm", + "r3.4xlarge": "hvm", + "r3.8xlarge": "hvm", + "r3.large": "hvm", + "r3.xlarge": "hvm", + "t1.micro": "pvm", + "t2.medium": "hvm", + "t2.micro": "hvm", + "t2.small": "hvm", +} + + +# Attempt to resolve an appropriate AMI given the architecture and region of the request. def get_spark_ami(opts): - instance_types = { - "c1.medium": "pvm", - "c1.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "cc1.4xlarge": "hvm", - "cc2.8xlarge": "hvm", - "cg1.4xlarge": "hvm", - "cr1.8xlarge": "hvm", - "hi1.4xlarge": "pvm", - "hs1.8xlarge": "pvm", - "i2.2xlarge": "hvm", - "i2.4xlarge": "hvm", - "i2.8xlarge": "hvm", - "i2.xlarge": "hvm", - "m1.large": "pvm", - "m1.medium": "pvm", - "m1.small": "pvm", - "m1.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", - "m2.xlarge": "pvm", - "m3.2xlarge": "hvm", - "m3.large": "hvm", - "m3.medium": "hvm", - "m3.xlarge": "hvm", - "r3.2xlarge": "hvm", - "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", - "t1.micro": "pvm", - "t2.medium": "hvm", - "t2.micro": "hvm", - "t2.small": "hvm", - } - if opts.instance_type in instance_types: - instance_type = instance_types[opts.instance_type] + if opts.instance_type in EC2_INSTANCE_TYPES: + instance_type = 
EC2_INSTANCE_TYPES[opts.instance_type] else: instance_type = "pvm" print >> stderr,\ "Don't recognize %s, assuming type is pvm" % opts.instance_type - ami_path = "%s/%s/%s" % (AMI_PREFIX, opts.region, instance_type) + # URL prefix from which to fetch AMI information + ami_prefix = "{r}/{b}/ami-list".format( + r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1), + b=opts.spark_ec2_git_branch) + + ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) try: ami = urllib2.urlopen(ami_path).read().strip() print "Spark AMI: " + ami @@ -349,6 +374,7 @@ def launch_cluster(conn, opts, cluster_name): if opts.identity_file is None: print >> stderr, "ERROR: Must provide an identity file (-i) for ssh connections." sys.exit(1) + if opts.key_pair is None: print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." sys.exit(1) @@ -569,6 +595,9 @@ def launch_cluster(conn, opts, cluster_name): master_nodes = master_res.instances print "Launched master in %s, regid = %s" % (zone, master_res.id) + # This wait time corresponds to SPARK-4983 + print "Waiting for AWS to propagate instance metadata..." + time.sleep(5) # Give the instances descriptive names for master in master_nodes: master.add_tag( @@ -585,10 +614,9 @@ def launch_cluster(conn, opts, cluster_name): # Get the EC2 instances in an existing cluster if available. # Returns a tuple of lists of EC2 instance objects for the masters and slaves - - def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): - print "Searching for existing cluster " + cluster_name + "..." + print "Searching for existing cluster " + cluster_name + " in region " \ + + opts.region + "..." 
reservations = conn.get_all_reservations() master_nodes = [] slave_nodes = [] @@ -606,9 +634,11 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" + print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name \ + + "-master" + " in region " + opts.region else: - print >> sys.stderr, "ERROR: Could not find any existing cluster" + print >> sys.stderr, "ERROR: Could not find any existing cluster" \ + + " in region " + opts.region sys.exit(1) @@ -643,12 +673,15 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten + print "Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( + r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch) ssh( host=master, opts=opts, command="rm -rf spark-ec2" + " && " - + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo, + b=opts.spark_ec2_git_branch) ) print "Deploying files to master..." @@ -675,21 +708,32 @@ def setup_spark_cluster(master, opts): print "Ganglia started at http://%s:5080/ganglia" % master -def is_ssh_available(host, opts): +def is_ssh_available(host, opts, print_ssh_output=True): """ Check if SSH is available on a host. 
""" - try: - with open(os.devnull, 'w') as devnull: - ret = subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=devnull, - stderr=devnull - ) - return ret == 0 - except subprocess.CalledProcessError as e: - return False + s = subprocess.Popen( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order + ) + cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout + + if s.returncode != 0 and print_ssh_output: + # extra leading newline is for spacing in wait_for_cluster_state() + print textwrap.dedent("""\n + Warning: SSH connection error. (This could be temporary.) + Host: {h} + SSH return code: {r} + SSH output: {o} + """).format( + h=host, + r=s.returncode, + o=cmd_output.strip() + ) + + return s.returncode == 0 def is_cluster_ssh_available(cluster_instances, opts): @@ -896,6 +940,7 @@ def stringify_command(parts): def ssh_args(opts): parts = ['-o', 'StrictHostKeyChecking=no'] + parts += ['-o', 'UserKnownHostsFile=/dev/null'] if opts.identity_file is not None: parts += ['-i', opts.identity_file] return parts @@ -1003,10 +1048,57 @@ def real_main(): DeprecationWarning ) + if opts.identity_file is not None: + if not os.path.exists(opts.identity_file): + print >> stderr,\ + "ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file) + sys.exit(1) + + file_mode = os.stat(opts.identity_file).st_mode + if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': + print >> stderr, "ERROR: The identity file must be accessible only by you." 
+ print >> stderr, 'You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file) + sys.exit(1) + + if opts.instance_type not in EC2_INSTANCE_TYPES: + print >> stderr, "Warning: Unrecognized EC2 instance type for instance-type: {t}".format( + t=opts.instance_type) + + if opts.master_instance_type != "": + if opts.master_instance_type not in EC2_INSTANCE_TYPES: + print >> stderr, \ + "Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( + t=opts.master_instance_type) + # Since we try instance types even if we can't resolve them, we check if they resolve first + # and, if they do, see if they resolve to the same virtualization type. + if opts.instance_type in EC2_INSTANCE_TYPES and \ + opts.master_instance_type in EC2_INSTANCE_TYPES: + if EC2_INSTANCE_TYPES[opts.instance_type] != \ + EC2_INSTANCE_TYPES[opts.master_instance_type]: + print >> stderr, \ + "Error: spark-ec2 currently does not support having a master and slaves with " + \ + "different AMI virtualization types." + print >> stderr, "master instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.master_instance_type]) + print >> stderr, "slave instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.instance_type]) + sys.exit(1) + if opts.ebs_vol_num > 8: print >> stderr, "ebs-vol-num cannot be greater than 8" sys.exit(1) + # Prevent breaking ami_prefix (/, .git and startswith checks) + # Prevent forks with non spark-ec2 names for now. + if opts.spark_ec2_git_repo.endswith("/") or \ + opts.spark_ec2_git_repo.endswith(".git") or \ + not opts.spark_ec2_git_repo.startswith("https://github.com") or \ + not opts.spark_ec2_git_repo.endswith("spark-ec2"): + print >> stderr, "spark-ec2-git-repo must be a github repo and it must not have a " \ + "trailing / or .git. " \ + "Furthermore, we currently only support forks named spark-ec2." 
+ sys.exit(1) + try: conn = ec2.connect_to_region(opts.region) except Exception as e: @@ -1082,11 +1174,12 @@ def real_main(): time.sleep(30) # Yes, it does have to be this long :-( for group in groups: try: - conn.delete_security_group(group.name) - print "Deleted security group " + group.name + # It is needed to use group_id to make it work with VPC + conn.delete_security_group(group_id=group.id) + print "Deleted security group %s" % group.name except boto.exception.EC2ResponseError: success = False - print "Failed to delete security group " + group.name + print "Failed to delete security group %s" % group.name # Unfortunately, group.revoke() returns True even if a rule was not # deleted, so this needs to be rerun if something fails diff --git a/examples/pom.xml b/examples/pom.xml index 4b92147725f6b..8caad2bc2e27a 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -35,12 +35,6 @@ http://spark.apache.org/ - - - com.google.guava - guava - compile - org.apache.spark spark-core_${scala.binary.version} @@ -310,69 +304,40 @@ org.apache.maven.plugins maven-shade-plugin - - - package - - shade - - - false - ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar - - - *:* - - - - - com.google.guava:guava - - - ** - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - com.google - org.spark-project.guava - - com.google.common.** - - - com.google.common.base.Optional** - - - - org.apache.commons.math3 - org.spark-project.commons.math3 - - - - - - reference.conf - - - log4j.properties - - - - - + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + org.apache.commons.math3 + org.spark-project.commons.math3 + + + + + + reference.conf + + + log4j.properties + + + diff --git 
a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java new file mode 100644 index 0000000000000..bab9f2478e779 --- /dev/null +++ b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Arrays; +import java.util.regex.Pattern; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import kafka.serializer.StringDecoder; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.kafka.KafkaUtils; +import org.apache.spark.streaming.Durations; + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. 
+ * Usage: JavaDirectKafkaWordCount <brokers> <topics> + * <brokers> is a list of one or more Kafka brokers + * <topics> is a list of one or more kafka topics to consume from + * + * Example: + * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 + */ + +public final class JavaDirectKafkaWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: DirectKafkaWordCount \n" + + " is a list of one or more Kafka brokers\n" + + " is a list of one or more kafka topics to consume from\n\n"); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + String brokers = args[0]; + String topics = args[1]; + + // Create context with 2 second batch interval + SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2)); + + HashSet topicsSet = new HashSet(Arrays.asList(topics.split(","))); + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", brokers); + + // Create direct kafka stream with brokers and topics + JavaPairInputDStream messages = KafkaUtils.createDirectStream( + jssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topicsSet + ); + + // Get the lines, split them into words, count the words and print + JavaDStream lines = messages.map(new Function, String>() { + @Override + public String call(Tuple2 tuple2) { + return tuple2._2(); + } + }); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + });
+ wordCounts.print(); + + // Start the computation + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala new file mode 100644 index 0000000000000..deb08fd57b8c7 --- /dev/null +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import kafka.serializer.StringDecoder + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.kafka._ +import org.apache.spark.SparkConf + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. 
+ * Usage: DirectKafkaWordCount <brokers> <topics> + * <brokers> is a list of one or more Kafka brokers + * <topics> is a list of one or more kafka topics to consume from + * + * Example: + * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 + */ +object DirectKafkaWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println(s""" + |Usage: DirectKafkaWordCount + | is a list of one or more Kafka brokers + | is a list of one or more kafka topics to consume from + | + """".stripMargin) + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + val Array(brokers, topics) = args + + // Create context with 2 second batch interval + val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) + + // Create direct kafka stream with brokers and topics + val topicsSet = topics.split(",").toSet + val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers) + val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topicsSet) + + // Get the lines, split them into words, count the words and print + val lines = messages.map(_._2) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _) + wordCounts.print() + + // Start the computation + ssc.start() + ssc.awaitTermination() + } +} diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 2adc63f7ff30e..387c0e421334b 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -76,7 +76,7 @@ object KafkaWordCountProducer { val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args - // Zookeper connection properties +
// Zookeeper connection properties val props = new Properties() props.put("metadata.broker.list", brokers) props.put("serializer.class", "kafka.serializer.StringEncoder") diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index f4b4f8d8c7b2f..9bbc14ea40875 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -33,9 +33,9 @@ import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple example demonstrating model selection using CrossValidator. @@ -55,7 +55,7 @@ public class JavaCrossValidatorExample { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); JavaSparkContext jsc = new JavaSparkContext(conf); - JavaSQLContext jsql = new JavaSQLContext(jsc); + SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -71,8 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. 
Tokenizer tokenizer = new Tokenizer() @@ -113,14 +112,15 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test).registerAsTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); - for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + DataFrame predictions = cvModel.transform(test); + for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java new file mode 100644 index 0000000000000..19d0eb216848e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.Classifier; +import org.apache.spark.ml.classification.ClassificationModel; +import org.apache.spark.ml.param.IntParam; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.param.Params; +import org.apache.spark.ml.param.Params$; +import org.apache.spark.mllib.linalg.BLAS; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + + +/** + * A simple example demonstrating how to write your own learning algorithm using Estimator, + * Transformer, and other abstractions. + * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}. + * + * Run with + *
+ * bin/run-example ml.JavaDeveloperApiExample
+ * 
+ */ +public class JavaDeveloperApiExample { + + public static void main(String[] args) throws Exception { + SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // Prepare training data. + List localTraining = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + + // Create a LogisticRegression instance. This instance is an Estimator. + MyJavaLogisticRegression lr = new MyJavaLogisticRegression(); + // Print out the parameters, documentation, and any default values. + System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n"); + + // We may set parameters using setter methods. + lr.setMaxIter(10); + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + MyJavaLogisticRegressionModel model = lr.fit(training); + + // Prepare test data. + List localTest = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); + + // Make predictions on the test data using the fitted model. 
+ DataFrame results = model.transform(test); + double sumPredictions = 0; + for (Row r : results.select("features", "label", "prediction").collect()) { + sumPredictions += r.getDouble(2); + } + if (sumPredictions != 0.0) { + throw new Exception("MyJavaLogisticRegression predicted something other than 0," + + " even though all weights are 0!"); + } + + jsc.stop(); + } +} + +/** + * Example of defining a type of {@link Classifier}. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +class MyJavaLogisticRegression + extends Classifier + implements Params { + + /** + * Param for max number of iterations + *

+ * NOTE: The usual way to add a parameter to a model or algorithm is to include: + * - val myParamName: ParamType + * - def getMyParamName + * - def setMyParamName + */ + IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations"); + + int getMaxIter() { return (Integer) get(maxIter); } + + public MyJavaLogisticRegression() { + setMaxIter(100); + } + + // The parameter setter is in this class since it should return type MyJavaLogisticRegression. + MyJavaLogisticRegression setMaxIter(int value) { + return (MyJavaLogisticRegression) set(maxIter, value); + } + + // This method is used by fit(). + // In Java, we have to make it public since Java does not understand Scala's protected modifier. + public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) { + // Extract columns from data using helper method. + JavaRDD oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD(); + + // Do learning to estimate the weight vector. + int numFeatures = oldDataset.take(1).get(0).features().size(); + Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. + + // Create a model, and return it. + return new MyJavaLogisticRegressionModel(this, paramMap, weights); + } +} + +/** + * Example of defining a type of {@link ClassificationModel}. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. 
+ */ +class MyJavaLogisticRegressionModel + extends ClassificationModel implements Params { + + private MyJavaLogisticRegression parent_; + public MyJavaLogisticRegression parent() { return parent_; } + + private ParamMap fittingParamMap_; + public ParamMap fittingParamMap() { return fittingParamMap_; } + + private Vector weights_; + public Vector weights() { return weights_; } + + public MyJavaLogisticRegressionModel( + MyJavaLogisticRegression parent_, + ParamMap fittingParamMap_, + Vector weights_) { + this.parent_ = parent_; + this.fittingParamMap_ = fittingParamMap_; + this.weights_ = weights_; + } + + // This uses the default implementation of transform(), which reads column "features" and outputs + // columns "prediction" and "rawPrediction." + + // This uses the default implementation of predict(), which chooses the label corresponding to + // the maximum value returned by [[predictRaw()]]. + + /** + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + * + * In Java, we have to make this method public since Java does not understand Scala's protected + * modifier. + */ + public Vector predictRaw(Vector features) { + double margin = BLAS.dot(features, weights_); + // There are 2 classes (binary classification), so we return a length-2 vector, + // where index i corresponds to class i (i = 0, 1). + return Vectors.dense(-margin, margin); + } + + /** + * Number of classes the label can take. 2 indicates binary classification. + */ + public int numClasses() { return 2; } + + /** + * Create a copy of the model. 
+ * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + *

+ * This is used for the default implementation of [[transform()]]. + * + * In Java, we have to make this method public since Java does not understand Scala's protected + * modifier. + */ + public MyJavaLogisticRegressionModel copy() { + MyJavaLogisticRegressionModel m = + new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_); + Params$.MODULE$.inheritValues(this.paramMap(), this, m); + return m; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index e25b271777ed4..4e02acce696e6 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -28,9 +28,9 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple example demonstrating ways to specify parameters for Estimators and Transformers. @@ -44,17 +44,17 @@ public class JavaSimpleParamsExample { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); JavaSparkContext jsc = new JavaSparkContext(conf); - JavaSQLContext jsql = new JavaSQLContext(jsc); + SQLContext jsql = new SQLContext(jsc); // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. 
List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -81,7 +81,7 @@ public static void main(String[] args) { // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); - paramMap2.put(lr.scoreCol().w("probability")); // Change output column name + paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -94,18 +94,18 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. - // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' - // column since we renamed the lr.scoreCol parameter previously. 
- model2.transform(test).registerAsTable("results"); - JavaSchemaRDD results = - jsql.sql("SELECT features, label, probability, prediction FROM results"); - for (Row r: results.collect()) { + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + DataFrame results = model2.transform(test); + for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 54f18014e4b2f..ef1ec103a879f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -21,6 +21,7 @@ import com.google.common.collect.Lists; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; @@ -28,10 +29,9 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; -import org.apache.spark.SparkConf; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple text classification pipeline that recognizes "spark" from input text. 
It uses the Java @@ -46,7 +46,7 @@ public class JavaSimpleTextClassificationPipeline { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); JavaSparkContext jsc = new JavaSparkContext(conf); - JavaSQLContext jsql = new JavaSQLContext(jsc); + SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -54,8 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -80,14 +79,15 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. 
- model.transform(test).registerAsTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); - for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + DataFrame predictions = model.transform(test); + for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java new file mode 100644 index 0000000000000..36baf5868736c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +import java.util.ArrayList; + +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; + +/** + * Java example for mining frequent itemsets using FP-growth. + * Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt + */ +public class JavaFPGrowthExample { + + public static void main(String[] args) { + String inputFile; + double minSupport = 0.3; + int numPartition = -1; + if (args.length < 1) { + System.err.println( + "Usage: JavaFPGrowth [minSupport] [numPartition]"); + System.exit(1); + } + inputFile = args[0]; + if (args.length >= 2) { + minSupport = Double.parseDouble(args[1]); + } + if (args.length >= 3) { + numPartition = Integer.parseInt(args[2]); + } + + SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD> transactions = sc.textFile(inputFile).map( + new Function>() { + @Override + public ArrayList call(String s) { + return Lists.newArrayList(s.split(" ")); + } + } + ); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartition) + .run(transactions); + + for (FPGrowth.FreqItemset s: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java new file mode 100644 index 0000000000000..36207ae38d9a9 --- /dev/null +++ 
b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class JavaLDAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("LDA Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs 
+ JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + )); + corpus.cache(); + + // Cluster the documents into three topics using LDA + DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java new file mode 100644 index 0000000000000..6c6f9768f015e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple3; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.clustering.PowerIterationClustering; +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; + +/** + * Java example for graph clustering using power iteration clustering (PIC). + */ +public class JavaPowerIterationClusteringExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaPowerIterationClusteringExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + @SuppressWarnings("unchecked") + JavaRDD> similarities = sc.parallelize(Lists.newArrayList( + new Tuple3(0L, 1L, 0.9), + new Tuple3(1L, 2L, 0.9), + new Tuple3(2L, 3L, 0.9), + new Tuple3(3L, 4L, 0.1), + new Tuple3(4L, 5L, 0.9))); + + PowerIterationClustering pic = new PowerIterationClustering() + .setK(2) + .setMaxIterations(10); + PowerIterationClusteringModel model = pic.run(similarities); + + for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { + System.out.println(a.id() + " -> " + a.cluster()); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 01c77bd44337e..dee794840a3e1 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -26,9 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import 
org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaSparkSQL { public static class Person implements Serializable { @@ -55,7 +55,7 @@ public void setAge(int age) { public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); - JavaSQLContext sqlCtx = new JavaSQLContext(ctx); + SQLContext sqlCtx = new SQLContext(ctx); System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. @@ -74,15 +74,15 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. - List teenagerNames = teenagers.map(new Function() { + List teenagerNames = teenagers.toJavaRDD().map(new Function() { @Override public String call(Row row) { return "Name: " + row.getString(0); @@ -93,19 +93,19 @@ public String call(Row row) { } System.out.println("=== Data source: Parquet File ==="); - // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. + // DataFrames can be saved as parquet files, maintaining the schema information. 
schemaPeople.saveAsParquetFile("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. - // The result of loading a parquet file is also a JavaSchemaRDD. - JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + // The result of loading a parquet file is also a DataFrame. + DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); - JavaSchemaRDD teenagers2 = + DataFrame teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); - teenagerNames = teenagers2.map(new Function() { + teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override public String call(Row row) { return "Name: " + row.getString(0); @@ -119,8 +119,8 @@ public String call(Row row) { // A JSON dataset is pointed by path. // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; - // Create a JavaSchemaRDD from the file(s) pointed by path - JavaSchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + // Create a DataFrame from the file(s) pointed by path + DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -130,15 +130,15 @@ public String call(Row row) { // |-- age: IntegerType // |-- name: StringType - // Register this JavaSchemaRDD as a table. + // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. 
- JavaSchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. - teenagerNames = teenagers3.map(new Function() { + teenagerNames = teenagers3.toJavaRDD().map(new Function() { @Override public String call(Row row) { return "Name: " + row.getString(0); } }).collect(); @@ -146,14 +146,14 @@ public String call(Row row) { System.out.println(name); } - // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by + // Alternatively, a DataFrame can be created for a JSON dataset represented by // a RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - JavaSchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD); + DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); - // Take a look at the schema of this new JavaSchemaRDD. + // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); // The schema of anotherPeople is ... 
// root @@ -164,8 +164,8 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); - List nameAndCity = peopleWithCity.map(new Function() { + DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { return "Name: " + row.getString(0) + ", City: " + row.getString(1); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java new file mode 100644 index 0000000000000..d46c7107c7a21 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.streaming; + +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + +import scala.Tuple2; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.StorageLevels; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every + * second starting with initial value of word count. + * Usage: JavaStatefulNetworkWordCount <hostname> <port> + * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive + * data. + *

+ * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example + * org.apache.spark.examples.streaming.JavaStatefulNetworkWordCount localhost 9999` + */ +public class JavaStatefulNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaStatefulNetworkWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + // Update the cumulative count function + final Function2, Optional, Optional> updateFunction = + new Function2, Optional, Optional>() { + @Override + public Optional call(List values, Optional state) { + Integer newSum = state.or(0); + for (Integer value : values) { + newSum += value; + } + return Optional.of(newSum); + } + }; + + // Create the context with a 1 second batch size + SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + ssc.checkpoint("."); + + // Initial RDD input to updateStateByKey + List> tuples = Arrays.asList(new Tuple2("hello", 1), + new Tuple2("world", 1)); + JavaPairRDD initialRDD = ssc.sc().parallelizePairs(tuples); + + JavaReceiverInputDStream lines = ssc.socketTextStream( + args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2); + + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + + JavaPairDStream wordsDstream = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }); + + // This will give a Dstream made of state (which is the cumulative count of the words) + JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, + new HashPartitioner(ssc.sc().defaultParallelism()), 
initialRDD); + + stateDstream.print(); + ssc.start(); + ssc.awaitTermination(); + } +} diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py new file mode 100644 index 0000000000000..d281f4fa44282 --- /dev/null +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row, SQLContext + + +""" +A simple text classification pipeline that recognizes "spark" from +input text. This is to show how to create and configure a Spark ML +pipeline in Python. Run with: + + bin/spark-submit examples/src/main/python/ml/simple_text_classification_pipeline.py +""" + + +if __name__ == "__main__": + sc = SparkContext(appName="SimpleTextClassificationPipeline") + sqlCtx = SQLContext(sc) + + # Prepare training documents, which are labeled. 
+ LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + + # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10, regParam=0.01) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # Fit the pipeline to training documents. + model = pipeline.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + + # Make predictions on test documents and print columns of interest. + prediction = model.transform(test) + selected = prediction.select("id", "text", "prediction") + for row in selected.collect(): + print row + + sc.stop() diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index 540dae785f6ea..b5a70db2b9a3c 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -16,7 +16,7 @@ # """ -An example of how to use SchemaRDD as a dataset for ML. Run with:: +An example of how to use DataFrame as a dataset for ML. 
Run with:: bin/spark-submit examples/src/main/python/mllib/dataset_example.py """ diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py new file mode 100644 index 0000000000000..a2cd626c9f19d --- /dev/null +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A Gaussian Mixture Model clustering program using MLlib. +""" +import sys +import random +import argparse +import numpy as np + +from pyspark import SparkConf, SparkContext +from pyspark.mllib.clustering import GaussianMixture + + +def parseVector(line): + return np.array([float(x) for x in line.split(' ')]) + + +if __name__ == "__main__": + """ + Parameters + ---------- + :param inputFile: Input file path which contains data points + :param k: Number of mixture components + :param convergenceTol: Convergence threshold. Default to 1e-3 + :param maxIterations: Number of EM iterations to perform. 
Default to 100 + :param seed: Random seed + """ + + parser = argparse.ArgumentParser() + parser.add_argument('inputFile', help='Input File') + parser.add_argument('k', type=int, help='Number of clusters') + parser.add_argument('--convergenceTol', default=1e-3, type=float, help='convergence threshold') + parser.add_argument('--maxIterations', default=100, type=int, help='Number of iterations') + parser.add_argument('--seed', default=random.getrandbits(19), + type=long, help='Random seed') + args = parser.parse_args() + + conf = SparkConf().setAppName("GMM") + sc = SparkContext(conf=conf) + + lines = sc.textFile(args.inputFile) + data = lines.map(parseVector) + model = GaussianMixture.train(data, args.k, args.convergenceTol, + args.maxIterations, args.seed) + for i in range(args.k): + print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, + "sigma = ", model.gaussians[i].sigma.toArray()) + print ("Cluster labels (first 100): ", model.predict(data).take(100)) + sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py new file mode 100644 index 0000000000000..e647773ad9060 --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosted_trees.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient boosted Trees classification and regression using MLlib. +""" + +import sys + +from pyspark.context import SparkContext +from pyspark.mllib.tree import GradientBoostedTrees +from pyspark.mllib.util import MLUtils + + +def testClassification(trainingData, testData): + # Train a GradientBoostedTrees model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={}, + numIterations=30, maxDepth=4) + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \ + / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification ensemble model:') + print(model.toDebugString()) + + +def testRegression(trainingData, testData): + # Train a GradientBoostedTrees model. + # Empty categoricalFeaturesInfo indicates all features are continuous. 
+ model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={}, + numIterations=30, maxDepth=4) + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \ + / float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression ensemble model:') + print(model.toDebugString()) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print >> sys.stderr, "Usage: gradient_boosted_trees" + exit(1) + sc = SparkContext(appName="PythonGradientBoostedTrees") + + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + print('\nRunning example of classification using GradientBoostedTrees\n') + testClassification(trainingData, testData) + + print('\nRunning example of regression using GradientBoostedTrees\n') + testRegression(trainingData, testData) + + sc.stop() diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index d2c5ca48c6cb8..47202fde7510b 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -30,18 +30,18 @@ some_rdd = sc.parallelize([Row(name="John", age=19), Row(name="Smith", age=23), Row(name="Sarah", age=18)]) - # Infer schema from the first row, create a SchemaRDD and print the schema - some_schemardd = sqlContext.inferSchema(some_rdd) - some_schemardd.printSchema() + # Infer schema from the first row, create a DataFrame and print the schema + some_df = sqlContext.createDataFrame(some_rdd) + some_df.printSchema() # Another RDD is created from a list of tuples another_rdd = sc.parallelize([("John", 19), ("Smith", 23), 
("Sarah", 18)]) # Schema with two fields - person_name and person_age schema = StructType([StructField("person_name", StringType(), False), StructField("person_age", IntegerType(), False)]) - # Create a SchemaRDD by applying the schema to the RDD and print the schema - another_schemardd = sqlContext.applySchema(another_rdd, schema) - another_schemardd.printSchema() + # Create a DataFrame by applying the schema to the RDD and print the schema + another_df = sqlContext.createDataFrame(another_rdd, schema) + another_df.printSchema() # root # |-- age: integer (nullable = true) # |-- name: string (nullable = true) @@ -49,7 +49,7 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") - # Create a SchemaRDD from the file(s) pointed to by path + # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root # |-- person_name: string (nullable = false) @@ -61,7 +61,7 @@ # |-- age: IntegerType # |-- name: StringType - # Register this SchemaRDD as a table. + # Register this DataFrame as a table. people.registerAsTable("people") # SQL statements can be run by using the sql methods provided by sqlContext diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py new file mode 100644 index 0000000000000..a33bdc475a06d --- /dev/null +++ b/examples/src/main/python/status_api_demo.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +import threading +import Queue + +from pyspark import SparkConf, SparkContext + + +def delayed(seconds): + def f(x): + time.sleep(seconds) + return x + return f + + +def call_in_background(f, *args): + result = Queue.Queue(1) + t = threading.Thread(target=lambda: result.put(f(*args))) + t.daemon = True + t.start() + return result + + +def main(): + conf = SparkConf().set("spark.ui.showConsoleProgress", "false") + sc = SparkContext(appName="PythonStatusAPIDemo", conf=conf) + + def run(): + rdd = sc.parallelize(range(10), 10).map(delayed(2)) + reduced = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) + return reduced.map(delayed(2)).collect() + + result = call_in_background(run) + status = sc.statusTracker() + while result.empty(): + ids = status.getJobIdsForGroup() + for id in ids: + job = status.getJobInfo(id) + print "Job", id, "status: ", job.status + for sid in job.stageIds: + info = status.getStageInfo(sid) + if info: + print "Stage %d: %d tasks total (%d active, %d complete)" % \ + (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks) + time.sleep(1) + + print "Job results are:", result.get() + sc.stop() + +if __name__ == "__main__": + main() diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py new file mode 100644 index 0000000000000..ed398a82b8bb0 --- /dev/null +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: kafka_wordcount.py <zk> <topic> + + To run this on your local machine, you need to set up Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + + and then run the example + `$ bin/spark-submit --driver-class-path external/kafka-assembly/target/scala-*/\ + spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ + localhost:2181 test` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: kafka_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingKafkaWordCount") + ssc = StreamingContext(sc, 1) + + zkQuorum, topic = sys.argv[1:] + kvs = KafkaUtils.createStream(ssc, zkQuorum, "spark-streaming-consumer", {topic: 1}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala 
b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 1b53f3edbe92e..4c129dbe2d12d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -29,7 +29,7 @@ object BroadcastTest { val blockSize = if (args.length > 3) args(3) else "4096" val sparkConf = new SparkConf().setAppName("Broadcast Test") - .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroaddcastFactory") + .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory") .set("spark.broadcast.blockSize", blockSize) val sc = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index 65251e93190f0..e757283823fc3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -19,6 +19,8 @@ package org.apache.spark.examples import scala.collection.JavaConversions._ +import org.apache.spark.util.Utils + /** Prints out environmental information, sleeps, and then exits. Made to * test driver submission in the standalone scheduler. 
*/ object DriverSubmissionTest { @@ -30,7 +32,7 @@ object DriverSubmissionTest { val numSecondsToSleep = args(0).toInt val env = System.getenv() - val properties = System.getProperties() + val properties = Utils.getSystemProperties println("Environment variables containing SPARK_TEST:") env.filter{case (k, v) => k.contains("SPARK_TEST")}.foreach(println) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index d8c7ef38ee46d..6c0af20461d3b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -18,12 +18,12 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} /** @@ -44,10 +44,10 @@ object CrossValidatorExample { val conf = new SparkConf().setAppName("CrossValidatorExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training documents, which are labeled. - val training = sparkContext.parallelize(Seq( + val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -90,21 +90,21 @@ object CrossValidatorExample { crossval.setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. 
- val cvModel = crossval.fit(training) + val cvModel = crossval.fit(training.toDF()) // Prepare test documents, which are unlabeled. - val test = sparkContext.parallelize(Seq( + val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) + cvModel.transform(test.toDF()) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala new file mode 100644 index 0000000000000..df26798e41b7b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel} +import org.apache.spark.ml.param.{Params, IntParam, ParamMap} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +/** + * A simple example demonstrating how to write your own learning algorithm using Estimator, + * Transformer, and other abstractions. + * This mimics [[org.apache.spark.ml.classification.LogisticRegression]]. + * Run with + * {{{ + * bin/run-example ml.DeveloperApiExample + * }}} + */ +object DeveloperApiExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("DeveloperApiExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Prepare training data. + val training = sc.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) + + // Create a LogisticRegression instance. This instance is an Estimator. + val lr = new MyLogisticRegression() + // Print out the parameters, documentation, and any default values. + println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + + // We may set parameters using setter methods. + lr.setMaxIter(10) + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + val model = lr.fit(training.toDF()) + + // Prepare test data. 
+ val test = sc.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) + + // Make predictions on test data. + val sumPredictions: Double = model.transform(test.toDF()) + .select("features", "label", "prediction") + .collect() + .map { case Row(features: Vector, label: Double, prediction: Double) => + prediction + }.sum + assert(sumPredictions == 0.0, + "MyLogisticRegression predicted something other than 0, even though all weights are 0!") + + sc.stop() + } +} + +/** + * Example of defining a parameter trait for a user-defined type of [[Classifier]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private trait MyLogisticRegressionParams extends ClassifierParams { + + /** + * Param for max number of iterations + * + * NOTE: The usual way to add a parameter to a model or algorithm is to include: + * - val myParamName: ParamType + * - def getMyParamName + * - def setMyParamName + * Here, we have a trait to be mixed in with the Estimator and Model (MyLogisticRegression + * and MyLogisticRegressionModel). We place the setter (setMaxIter) method in the Estimator + * class since the maxIter parameter is only used during training (not in the Model). + */ + val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + def getMaxIter: Int = get(maxIter) +} + +/** + * Example of defining a type of [[Classifier]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private class MyLogisticRegression + extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel] + with MyLogisticRegressionParams { + + setMaxIter(100) // Initialize + + // The parameter setter is in this class since it should return type MyLogisticRegression. 
+ def setMaxIter(value: Int): this.type = set(maxIter, value) + + // This method is used by fit() + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): MyLogisticRegressionModel = { + // Extract columns from data using helper method. + val oldDataset = extractLabeledPoints(dataset, paramMap) + + // Do learning to estimate the weight vector. + val numFeatures = oldDataset.take(1)(0).features.size + val weights = Vectors.zeros(numFeatures) // Learning would happen here. + + // Create a model, and return it. + new MyLogisticRegressionModel(this, paramMap, weights) + } +} + +/** + * Example of defining a type of [[ClassificationModel]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private class MyLogisticRegressionModel( + override val parent: MyLogisticRegression, + override val fittingParamMap: ParamMap, + val weights: Vector) + extends ClassificationModel[Vector, MyLogisticRegressionModel] + with MyLogisticRegressionParams { + + // This uses the default implementation of transform(), which reads column "features" and outputs + // columns "prediction" and "rawPrediction." + + // This uses the default implementation of predict(), which chooses the label corresponding to + // the maximum value returned by [[predictRaw()]]. + + /** + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. 
+ */ + override protected def predictRaw(features: Vector): Vector = { + val margin = BLAS.dot(features, weights) + // There are 2 classes (binary classification), so we return a length-2 vector, + // where index i corresponds to class i (i = 0, 1). + Vectors.dense(-margin, margin) + } + + /** Number of classes the label can take. 2 indicates binary classification. */ + override val numClasses: Int = 2 + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + * + * This is used for the default implementation of [[transform()]]. + */ + override protected def copy(): MyLogisticRegressionModel = { + val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights) + Params.inheritValues(this.paramMap, this, m) + m + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala new file mode 100644 index 0000000000000..25f21113bf622 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
 * Run with
 * {{{
 * bin/run-example ml.MovieLensALS
 * }}}
 */
object MovieLensALS {

  /** One record of the ratings file. */
  case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)

  object Rating {
    /** Parses a line of the form `userId::movieId::rating::timestamp`. */
    def parseRating(str: String): Rating = {
      val fields = str.split("::")
      assert(fields.size == 4)
      Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
    }
  }

  /** One record of the movies file. */
  case class Movie(movieId: Int, title: String, genres: Seq[String])

  object Movie {
    /** Parses a line of the form `movieId::title::genre1|genre2|...`. */
    def parseMovie(str: String): Movie = {
      val fields = str.split("::")
      assert(fields.size == 3)
      // String.split takes a regular expression, and an unescaped "|" is an alternation of two
      // empty patterns that matches between every character. Escape it to split on the
      // literal pipe that separates genres.
      Movie(fields(0).toInt, fields(1), fields(2).split("\\|"))
    }
  }

  /** Command-line parameters of the example. */
  case class Params(
      ratings: String = null,
      movies: String = null,
      maxIter: Int = 10,
      regParam: Double = 0.1,
      rank: Int = 10,
      numBlocks: Int = 10) extends AbstractParams[Params]

  /** Parses the command line and runs the example; exits with status 1 on invalid arguments. */
  def main(args: Array[String]): Unit = {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("MovieLensALS") {
      head("MovieLensALS: an example app for ALS on MovieLens data.")
      opt[String]("ratings")
        .required()
        .text("path to a MovieLens dataset of ratings")
        .action((x, c) => c.copy(ratings = x))
      opt[String]("movies")
        .required()
        .text("path to a MovieLens dataset of movies")
        .action((x, c) => c.copy(movies = x))
      opt[Int]("rank")
        .text(s"rank, default: ${defaultParams.rank}")
        .action((x, c) => c.copy(rank = x))
      opt[Int]("maxIter")
        .text(s"max number of iterations, default: ${defaultParams.maxIter}")
        .action((x, c) => c.copy(maxIter = x))
      opt[Double]("regParam")
        .text(s"regularization parameter, default: ${defaultParams.regParam}")
        .action((x, c) => c.copy(regParam = x))
      opt[Int]("numBlocks")
        .text(s"number of blocks, default: ${defaultParams.numBlocks}")
        .action((x, c) => c.copy(numBlocks = x))
      note(
        """
          |Example command line to run this app:
          |
          | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
          |  examples/target/scala-*/spark-examples-*.jar \
          |  --rank 10 --maxIter 15 --regParam 0.1 \
          |  --movies data/mllib/als/sample_movielens_movies.txt \
          |  --ratings data/mllib/als/sample_movielens_ratings.txt
        """.stripMargin)
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      System.exit(1)
    }
  }

  /**
   * Loads the ratings, fits an ALS model on an 80/20 split, prints the test RMSE and lists
   * "false positives" (rating <= 1 but prediction >= 4) joined with their movie titles.
   */
  def run(params: Params): Unit = {
    val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()

    val numRatings = ratings.count()
    val numUsers = ratings.map(_.userId).distinct().count()
    val numMovies = ratings.map(_.movieId).distinct().count()

    println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")

    // Fixed seed (0L) so that the split is the same on every run.
    val splits = ratings.randomSplit(Array(0.8, 0.2), 0L)
    val training = splits(0).cache()
    val test = splits(1).cache()

    val numTraining = training.count()
    val numTest = test.count()
    println(s"Training: $numTraining, test: $numTest.")

    // Both splits are cached and counted above, so the full set is no longer needed.
    ratings.unpersist(blocking = false)

    val als = new ALS()
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRank(params.rank)
      .setMaxIter(params.maxIter)
      .setRegParam(params.regParam)
      .setNumBlocks(params.numBlocks)

    val model = als.fit(training.toDF())

    val predictions = model.transform(test.toDF()).cache()

    // Evaluate the model.
    // TODO: Create an evaluator to compute RMSE.
    val mse = predictions.select("rating", "prediction").rdd
      .flatMap { case Row(rating: Float, prediction: Float) =>
        val err = rating.toDouble - prediction
        val err2 = err * err
        // Rows whose squared error is NaN are left out of the mean.
        if (err2.isNaN) {
          None
        } else {
          Some(err2)
        }
      }.mean()
    val rmse = math.sqrt(mse)
    println(s"Test RMSE = $rmse.")

    // Inspect false positives.
    // Note: We reference columns in 2 ways:
    //  (1) predictions("movieId") lets us specify the movieId column in the predictions
    //      DataFrame, rather than the movieId column in the movies DataFrame.
    //  (2) $"userId" specifies the userId column in the predictions DataFrame.
    //      We could also write predictions("userId") but do not have to since
    //      the movies DataFrame does not have a column "userId."
    val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF()
    val falsePositives = predictions.join(movies)
      .where((predictions("movieId") === movies("movieId"))
        && ($"rating" <= 1) && ($"prediction" >= 4))
      .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction")
    val numFalsePositives = falsePositives.count()
    println(s"Found $numFalsePositives false positives")
    if (numFalsePositives > 0) {
      println(s"Example false positives:")
      falsePositives.limit(100).collect().foreach(println)
    }

    sc.stop()
  }
}
-38,12 +37,12 @@ object SimpleParamsExample { val conf = new SparkConf().setAppName("SimpleParamsExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training data. - // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. - val training = sparkContext.parallelize(Seq( + // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes + // into DataFrames, where it uses the case class metadata to infer the schema. + val training = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), @@ -59,7 +58,7 @@ object SimpleParamsExample { .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. - val model1 = lr.fit(training) + val model1 = lr.fit(training.toDF()) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -73,29 +72,29 @@ object SimpleParamsExample { paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. - val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name + val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. - val model2 = lr.fit(training, paramMapCombined) + val model2 = lr.fit(training.toDF(), paramMapCombined) println("Model 2 was fit using parameters: " + model2.fittingParamMap) - // Prepare test documents. 
- val test = sparkContext.parallelize(Seq( + // Prepare test data. + val test = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) - // Make predictions on test documents using the Transformer.transform() method. + // Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. - // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' - // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test) - .select('features, 'label, 'probability, 'prediction) + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + model2.transform(test.toDF()) + .select("features", "label", "myProbability", "prediction") .collect() - .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => - println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index b9a6ef0229def..6772efd2c581c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -20,10 +20,10 @@ package org.apache.spark.examples.ml import scala.beans.BeanInfo import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import 
org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} @BeanInfo @@ -45,10 +45,10 @@ object SimpleTextClassificationPipeline { val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training documents, which are labeled. - val training = sparkContext.parallelize(Seq( + val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -69,21 +69,21 @@ object SimpleTextClassificationPipeline { .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. - val model = pipeline.fit(training) + val model = pipeline.fit(training.toDF()) // Prepare test documents, which are unlabeled. - val test = sparkContext.parallelize(Seq( + val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. 
- model.transform(test) - .select('id, 'text, 'score, 'prediction) + model.transform(test.toDF()) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index f8d83f4ec7327..e943d6c889fab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -28,10 +28,10 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} +import org.apache.spark.sql.{Row, SQLContext, DataFrame} /** - * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. 
Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] * }}} @@ -47,7 +47,7 @@ object DatasetExample { val defaultParams = Params() val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + head("Dataset: an example app using DataFrame as a Dataset for ML.") opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) @@ -71,7 +71,7 @@ object DatasetExample { val conf = new SparkConf().setAppName(s"DatasetExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ // for implicit conversions + import sqlContext.implicits._ // for implicit conversions // Load input data val origData: RDD[LabeledPoint] = params.dataFormat match { @@ -80,20 +80,20 @@ object DatasetExample { } println(s"Loaded ${origData.count()} instances from file: ${params.input}") - // Convert input data to SchemaRDD explicitly. - val schemaRDD: SchemaRDD = origData - println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") - println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + // Convert input data to DataFrame explicitly. + val df: DataFrame = origData.toDF() + println(s"Inferred schema:\n${df.schema.prettyJson}") + println(s"Converted to DataFrame with ${df.count()} records") - // Select columns, using implicit conversion to SchemaRDD. 
- val labelsSchemaRDD: SchemaRDD = origData.select('label) - val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + // Select columns + val labelsDf: DataFrame = df.select("label") + val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresSchemaRDD: SchemaRDD = origData.select('features) - val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featuresDf: DataFrame = df.select("features") + val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) @@ -103,13 +103,13 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - schemaRDD.saveAsParquetFile(outputDir) + df.saveAsParquetFile(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") val newDataset = sqlContext.parquetFile(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 205d80dd02682..262fd2c9611d0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -272,6 +272,8 @@ 
object DecisionTreeRunner { case Variance => impurity.Variance } + params.checkpointDir.foreach(sc.setCheckpointDir) + val strategy = new Strategy( algo = params.algo, @@ -282,7 +284,6 @@ object DecisionTreeRunner { minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain, useNodeIdCache = params.useNodeIdCache, - checkpointDir = params.checkpointDir, checkpointInterval = params.checkpointInterval) if (params.numTrees == 1) { val startTime = System.nanoTime() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala similarity index 88% rename from examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala rename to examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index 948c350953e27..df76b45e50810 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -18,17 +18,17 @@ package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.clustering.GaussianMixtureEM +import org.apache.spark.mllib.clustering.GaussianMixture import org.apache.spark.mllib.linalg.Vectors /** * An example Gaussian Mixture Model EM app. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM + * ./bin/run-example mllib.DenseGaussianMixture * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
*/ -object DenseGmmEM { +object DenseGaussianMixture { def main(args: Array[String]): Unit = { if (args.length < 3) { println("usage: DenseGmmEM [maxIterations]") @@ -46,7 +46,7 @@ object DenseGmmEM { Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - val clusters = new GaussianMixtureEM() + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) @@ -54,7 +54,7 @@ object DenseGmmEM { for (i <- 0 until clusters.k) { println("weight=%f\nmu=%s\nsigma=\n%s\n" format - (clusters.weight(i), clusters.mu(i), clusters.sigma(i))) + (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } println("Cluster labels (first <= 100):") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 11e35598baf50..14cc5cbb679c5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -56,7 +56,7 @@ object DenseKMeans { .text(s"number of clusters, required") .action((x, c) => c.copy(k = x)) opt[Int]("numIterations") - .text(s"number of iterations, default; ${defaultParams.numIterations}") + .text(s"number of iterations, default: ${defaultParams.numIterations}") .action((x, c) => c.copy(numIterations = x)) opt[String]("initMode") .text(s"initialization mode (${InitializationMode.values.mkString(",")}), " + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala new file mode 100644 index 0000000000000..13f24a1e59610 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
/**
 * Example for mining frequent itemsets using FP-growth.
 * Example usage: ./bin/run-example mllib.FPGrowthExample \
 *   --minSupport 0.8 --numPartition 2 ./data/mllib/sample_fpgrowth.txt
 */
object FPGrowthExample {

  /**
   * Command-line parameters.
   *
   * @param input path of the transactions file (one transaction per line, items separated by
   *              a space)
   * @param minSupport value passed to `FPGrowth.setMinSupport`
   * @param numPartition value passed to `FPGrowth.setNumPartitions`
   */
  case class Params(
      input: String = null,
      minSupport: Double = 0.3,
      numPartition: Int = -1) extends AbstractParams[Params]

  /** Parses the command line and runs the example; exits with status 1 on invalid arguments. */
  def main(args: Array[String]): Unit = {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("FPGrowthExample") {
      head("FPGrowth: an example FP-growth app.")
      opt[Double]("minSupport")
        .text(s"minimal support level, default: ${defaultParams.minSupport}")
        .action((x, c) => c.copy(minSupport = x))
      opt[Int]("numPartition")
        .text(s"number of partition, default: ${defaultParams.numPartition}")
        .action((x, c) => c.copy(numPartition = x))
      // The positional argument needs a name, otherwise the generated usage text shows an
      // empty placeholder for the required input path.
      arg[String]("<input>")
        .text("input paths to input data set, whose file format is that each line " +
          "contains a transaction with each item in String and separated by a space")
        .required()
        .action((x, c) => c.copy(input = x))
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      sys.exit(1)
    }
  }

  /** Reads the transactions, runs FP-growth and prints every frequent itemset with its count. */
  def run(params: Params): Unit = {
    val conf = new SparkConf().setAppName(s"FPGrowthExample with $params")
    val sc = new SparkContext(conf)
    // Each line is one transaction; items are separated by a single space.
    val transactions = sc.textFile(params.input).map(_.split(" ")).cache()

    println(s"Number of transactions: ${transactions.count()}")

    val model = new FPGrowth()
      .setMinSupport(params.minSupport)
      .setNumPartitions(params.numPartition)
      .run(transactions)

    println(s"Number of frequent itemsets: ${model.freqItemsets.count()}")

    model.freqItemsets.collect().foreach { itemset =>
      println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
    }

    sc.stop()
  }
}
/**
 * An example Latent Dirichlet Allocation (LDA) app. Run with
 * {{{
 * ./bin/run-example mllib.LDAExample [options] <input>
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object LDAExample {

  /** Command-line parameters; see the option help texts in [[main]] for their meaning. */
  private case class Params(
      input: Seq[String] = Seq.empty,
      k: Int = 20,
      maxIterations: Int = 10,
      docConcentration: Double = -1,
      topicConcentration: Double = -1,
      vocabSize: Int = 10000,
      stopwordFile: String = "",
      checkpointDir: Option[String] = None,
      checkpointInterval: Int = 10) extends AbstractParams[Params]

  /** Parses the command line and runs the example; prints usage and exits with 1 on error. */
  def main(args: Array[String]): Unit = {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("LDAExample") {
      head("LDAExample: an example LDA app for plain text data.")
      opt[Int]("k")
        .text(s"number of topics. default: ${defaultParams.k}")
        .action((x, c) => c.copy(k = x))
      opt[Int]("maxIterations")
        .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}")
        .action((x, c) => c.copy(maxIterations = x))
      opt[Double]("docConcentration")
        .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." +
          s"  default: ${defaultParams.docConcentration}")
        .action((x, c) => c.copy(docConcentration = x))
      opt[Double]("topicConcentration")
        .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." +
          s"  default: ${defaultParams.topicConcentration}")
        .action((x, c) => c.copy(topicConcentration = x))
      opt[Int]("vocabSize")
        .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" +
          s"  default: ${defaultParams.vocabSize}")
        .action((x, c) => c.copy(vocabSize = x))
      opt[String]("stopwordFile")
        .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." +
          s"  default: ${defaultParams.stopwordFile}")
        .action((x, c) => c.copy(stopwordFile = x))
      opt[String]("checkpointDir")
        .text(s"Directory for checkpointing intermediate results." +
          s"  Checkpointing helps with recovery and eliminates temporary shuffle files on disk." +
          s"  default: ${defaultParams.checkpointDir}")
        .action((x, c) => c.copy(checkpointDir = Some(x)))
      opt[Int]("checkpointInterval")
        .text(s"Iterations between each checkpoint.  Only used if checkpointDir is set." +
          s" default: ${defaultParams.checkpointInterval}")
        .action((x, c) => c.copy(checkpointInterval = x))
      // Name the positional argument so the usage text shows what is expected; it is
      // unbounded, so every further argument is appended to the list of input paths.
      arg[String]("<input>...")
        .text("input paths (directories) to plain text corpora." +
          "  Each text file line should hold 1 document.")
        .unbounded()
        .required()
        .action((x, c) => c.copy(input = c.input :+ x))
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      parser.showUsageAsError
      sys.exit(1)
    }
  }

  /** Preprocesses the corpus, trains an LDA model and prints timing and the top topic terms. */
  private def run(params: Params): Unit = {
    val conf = new SparkConf().setAppName(s"LDAExample with $params")
    val sc = new SparkContext(conf)

    Logger.getRootLogger.setLevel(Level.WARN)

    // Load documents, and prepare them for LDA.
    val preprocessStart = System.nanoTime()
    val (corpus, vocabArray, actualNumTokens) =
      preprocess(sc, params.input, params.vocabSize, params.stopwordFile)
    corpus.cache()
    val actualCorpusSize = corpus.count()
    val actualVocabSize = vocabArray.size
    val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9

    println()
    println(s"Corpus summary:")
    println(s"\t Training set size: $actualCorpusSize documents")
    println(s"\t Vocabulary size: $actualVocabSize terms")
    println(s"\t Training set size: $actualNumTokens tokens")
    println(s"\t Preprocessing time: $preprocessElapsed sec")
    println()

    // Run LDA.
    val lda = new LDA()
    lda.setK(params.k)
      .setMaxIterations(params.maxIterations)
      .setDocConcentration(params.docConcentration)
      .setTopicConcentration(params.topicConcentration)
      .setCheckpointInterval(params.checkpointInterval)
    // Only configure checkpointing when a directory was given on the command line.
    params.checkpointDir.foreach(sc.setCheckpointDir)
    val startTime = System.nanoTime()
    val ldaModel = lda.run(corpus)
    val elapsed = (System.nanoTime() - startTime) / 1e9

    println(s"Finished training LDA model.  Summary:")
    println(s"\t Training time: $elapsed sec")
    val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble
    println(s"\t Training data average log likelihood: $avgLogLikelihood")
    println()

    // Print the topics, showing the top-weighted terms for each topic.
    val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
    val topics = topicIndices.map { case (terms, termWeights) =>
      // Translate term indices back into the words of the vocabulary.
      terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) }
    }
    println(s"${params.k} topics:")
    topics.zipWithIndex.foreach { case (topic, i) =>
      println(s"TOPIC $i")
      topic.foreach { case (term, weight) =>
        println(s"$term\t$weight")
      }
      println()
    }
    sc.stop()
  }

  /**
   * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors.
   *
   * @param sc Spark context used to read the input files
   * @param paths input paths; they are joined with "," and read with a single textFile call
   * @param vocabSize maximum number of terms to keep, chosen by frequency (-1 keeps all terms)
   * @param stopwordFile path of a stopword list, or the empty string for none
   * @return (corpus, vocabulary as array, total token count in corpus)
   */
  private def preprocess(
      sc: SparkContext,
      paths: Seq[String],
      vocabSize: Int,
      stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {

    // Get dataset of document texts
    // One document per line in each text file.
    val textRDD: RDD[String] = sc.textFile(paths.mkString(","))

    // Split text into words
    val tokenizer = new SimpleTokenizer(sc, stopwordFile)
    val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
      id -> tokenizer.getWords(text)
    }
    tokenized.cache()

    // Counts words: RDD[(word, wordCount)]
    val wordCounts: RDD[(String, Long)] = tokenized
      .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
      .reduceByKey(_ + _)
    wordCounts.cache()
    val fullVocabSize = wordCounts.count()
    // Select vocab
    //  (vocab: Map[word -> id], total tokens after selecting vocab)
    val (vocab: Map[String, Int], selectedTokenCount: Long) = {
      val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
        // Use all terms
        wordCounts.collect().sortBy(-_._2)
      } else {
        // Sort terms to select vocab
        wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
      }
      (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
    }

    val documents = tokenized.map { case (id, tokens) =>
      // Filter tokens by vocabulary, and create word count vector representation of document.
      val wc = new mutable.HashMap[Int, Int]()
      tokens.foreach { term =>
        if (vocab.contains(term)) {
          val termIndex = vocab(term)
          wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
        }
      }
      // Indices are sorted before the sparse vector is built from them.
      val indices = wc.keys.toArray.sorted
      val values = indices.map(i => wc(i).toDouble)

      val sb = Vectors.sparse(vocab.size, indices, values)
      (id, sb)
    }

    // Invert the word -> id map into an array indexed by id.
    val vocabArray = new Array[String](vocab.size)
    vocab.foreach { case (term, i) => vocabArray(i) = term }

    (documents, vocabArray, selectedTokenCount)
  }
}

/**
 * Simple Tokenizer.
 *
 * Splits text into lowercase words made only of Unicode letters, dropping short words and
 * stopwords.
 *
 * TODO: Formalize the interface, and make this a public class in mllib.feature
 *
 * @param sc Spark context used to read the stopword file
 * @param stopwordFile path of a whitespace-separated stopword list, or the empty string for none
 */
private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {

  // The stopword file is collected to the driver, so it must fit on a single machine.
  private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
    Set.empty[String]
  } else {
    val stopwordText = sc.textFile(stopwordFile).collect()
    stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
  }

  // Matches sequences of Unicode letters
  private val allWordRegex = "^(\\p{L}*)$".r

  // Ignore words shorter than this length.
  private val minWordLength = 3

  /**
   * Tokenizes one document.
   *
   * @param text document text
   * @return the kept words, lowercased, in order of appearance
   */
  def getWords(text: String): IndexedSeq[String] = {

    val words = new mutable.ArrayBuffer[String]()

    // Use Java BreakIterator to tokenize text into words.
    val wb = BreakIterator.getWordInstance
    wb.setText(text)

    // `current` and `end` are the start and end index of the word being examined.
    var current = wb.first()
    var end = wb.next()
    while (end != BreakIterator.DONE) {
      // Convert to lowercase
      val word: String = text.substring(current, end).toLowerCase
      // Remove short words and strings that aren't only letters
      word match {
        case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
          words += w
        case _ =>
      }

      current = end
      try {
        end = wb.next()
      } catch {
        case e: Exception =>
          // Ignore remaining text in line.
          // This is a known bug in BreakIterator (for some Java versions),
          // which fails when it sees certain characters.
          end = BreakIterator.DONE
      }
    }
    words
  }

}
+ * + * Run with + * {{{ + * ./bin/run-example mllib.PowerIterationClusteringExample [options] + * + * Where options include: + * k: Number of circles/clusters + * n: Number of sampled points on innermost circle. There are proportionally more points + * within the outer/larger circles + * maxIterations: Number of Power Iterations + * r: radius of the outermost of the concentric circles + * }}} + * + * Here is a sample run and output: + * + * ./bin/run-example mllib.PowerIterationClusteringExample -k 3 --n 30 --maxIterations 15 + * + * Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14], + * 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] + * + * + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object PowerIterationClusteringExample { + + case class Params( + input: String = null, + k: Int = 3, + numPoints: Int = 5, + maxIterations: Int = 10, + outerRadius: Double = 3.0 + ) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("PIC Circles") { + head("PowerIterationClusteringExample: an example PIC app using concentric circles.") + opt[Int]('k', "k") + .text(s"number of circles (/clusters), default: ${defaultParams.k}") + .action((x, c) => c.copy(k = x)) + opt[Int]('n', "n") + .text(s"number of points in smallest circle, default: ${defaultParams.numPoints}") + .action((x, c) => c.copy(numPoints = x)) + opt[Int]("maxIterations") + .text(s"number of iterations, default: ${defaultParams.maxIterations}") + .action((x, c) => c.copy(maxIterations = x)) + opt[Double]('r', "r") + .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") + .action((x, c) => c.copy(outerRadius = x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf() + .setMaster("local") +
.setAppName(s"PowerIterationClustering with $params") + val sc = new SparkContext(conf) + + Logger.getRootLogger.setLevel(Level.WARN) + + val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints, params.outerRadius) + val model = new PowerIterationClustering() + .setK(params.k) + .setMaxIterations(params.maxIterations) + .run(circlesRdd) + + val clusters = model.assignments.collect().groupBy(_.cluster).mapValues(_.map(_.id)) + val assignments = clusters.toList.sortBy { case (k, v) => v.length} + val assignmentsStr = assignments + .map { case (k, v) => + s"$k -> ${v.sorted.mkString("[", ",", "]")}" + }.mkString(",") + val sizesStr = assignments.map { + _._2.size + }.sorted.mkString("(", ",", ")") + println(s"Cluster assignments: $assignmentsStr\ncluster sizes: $sizesStr") + + sc.stop() + } + + def generateCircle(radius: Double, n: Int) = { + Seq.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (radius * math.cos(theta), radius * math.sin(theta)) + } + } + + def generateCirclesRdd(sc: SparkContext, + nCircles: Int = 3, + nPoints: Int = 30, + outerRadius: Double): RDD[(Long, Long, Double)] = { + + val radii = Array.tabulate(nCircles) { cx => outerRadius / (nCircles - cx)} + val groupSizes = Array.tabulate(nCircles) { cx => (cx + 1) * nPoints} + val points = (0 until nCircles).flatMap { cx => + generateCircle(radii(cx), groupSizes(cx)) + }.zipWithIndex + val rdd = sc.parallelize(points) + val distancesRdd = rdd.cartesian(rdd).flatMap { case (((x0, y0), i0), ((x1, y1), i1)) => + if (i0 < i1) { + Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1), 1.0))) + } else { + None + } + } + distancesRdd + } + + /** + * Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel + */ + def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double) = { + val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma) + val expCoeff = -1.0 / (2.0 * math.pow(sigma, 2.0)) + val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) +
(p1._2 - p2._2) * (p1._2 - p2._2) + coeff * math.exp(expCoeff * ssquares) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index c5bd5b0b178d9..1a95048bbfe2d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -35,8 +35,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} * * To run on your local machine using the two directories `trainingDir` and `testDir`, * with updates every 5 seconds, and 2 features per data point, call: - * $ bin/run-example \ - * org.apache.spark.examples.mllib.StreamingLinearRegression trainingDir testDir 5 2 + * $ bin/run-example mllib.StreamingLinearRegression trainingDir testDir 5 2 * * As you add text files to `trainingDir` the model will continuously update. * Anytime you add text files to `testDir`, you'll see predictions from the current model. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala new file mode 100644 index 0000000000000..e1998099c2d78 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Train a logistic regression model on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the text files must be labeled data points in the form + * `(y,[x1,x2,x3,...,xn])` + * Where n is the number of features, y is a binary label, and + * n must be the same for train and test. + * + * Usage: StreamingLogisticRegression + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, and 2 features per data point, call: + * $ bin/run-example mllib.StreamingLogisticRegression trainingDir testDir 5 2 + * + * As you add text files to `trainingDir` the model will continuously update. + * Anytime you add text files to `testDir`, you'll see predictions from the current model. 
+ * + */ +object StreamingLogisticRegression { + + def main(args: Array[String]) { + + if (args.length != 4) { + System.err.println( + "Usage: StreamingLogisticRegression ") + System.exit(1) + } + + val conf = new SparkConf().setMaster("local").setAppName("StreamingLogisticRegression") + val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) + + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.zeros(args(3).toInt)) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 2e98b2dc30b80..6331d1c0060f8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,6 +19,7 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.functions._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -31,43 +32,43 @@ object RDDRelation { val sqlContext = new SQLContext(sc) // Importing the SQL context gives access to all the SQL functions and implicit conversions. - import sqlContext._ + import sqlContext.implicits._ - val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) + val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF() // Any RDD containing case classes can be registered as a table. The schema of the table is // automatically inferred using scala reflection. 
- rdd.registerTempTable("records") + df.registerTempTable("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") - sql("SELECT * FROM records").collect().foreach(println) + sqlContext.sql("SELECT * FROM records").collect().foreach(println) // Aggregation queries are also supported. - val count = sql("SELECT COUNT(*) FROM records").collect().head.getLong(0) + val count = sqlContext.sql("SELECT COUNT(*) FROM records").collect().head.getLong(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The // items in the RDD are of type Row, which allows you to access each column by ordinal. - val rddFromSql = sql("SELECT key, value FROM records WHERE key < 10") + val rddFromSql = sqlContext.sql("SELECT key, value FROM records WHERE key < 10") println("Result of RDD.map:") rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println) + df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. - rdd.saveAsParquetFile("pair.parquet") + df.saveAsParquetFile("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. val parquetFile = sqlContext.parquetFile("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. - parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) + parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) // These files can also be registered as tables. 
parquetFile.registerTempTable("parquetFile") - sql("SELECT * FROM parquetFile").collect().foreach(println) + sqlContext.sql("SELECT * FROM parquetFile").collect().foreach(println) sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 5725da1848114..b7ba60ec28155 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -43,7 +43,8 @@ object HiveFromSpark { // HiveContext. When not configured by the hive-site.xml, the context automatically // creates metastore_db and warehouse in the current directory. val hiveContext = new HiveContext(sc) - import hiveContext._ + import hiveContext.implicits._ + import hiveContext.sql sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(s"LOAD DATA LOCAL INPATH '${kv1File.getAbsolutePath}' INTO TABLE src") @@ -67,7 +68,7 @@ object HiveFromSpark { // You can also register RDDs as temporary tables within a HiveContext. val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.registerTempTable("records") + rdd.toDF().registerTempTable("records") // Queries can then join RDD data with data stored in Hive. 
println("Result of SELECT *:") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 6ff0c47793a25..f40caad322f59 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -17,8 +17,8 @@ package org.apache.spark.examples.streaming -import org.eclipse.paho.client.mqttv3.{MqttClient, MqttClientPersistence, MqttException, MqttMessage, MqttTopic} -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -31,8 +31,6 @@ import org.apache.spark.SparkConf */ object MQTTPublisher { - var client: MqttClient = _ - def main(args: Array[String]) { if (args.length < 2) { System.err.println("Usage: MQTTPublisher ") @@ -42,25 +40,36 @@ object MQTTPublisher { StreamingExamples.setStreamingLogLevels() val Seq(brokerUrl, topic) = args.toSeq + + var client: MqttClient = null try { - var peristance:MqttClientPersistence =new MqttDefaultFilePersistence("/tmp") - client = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) + val persistence = new MemoryPersistence() + client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) + + client.connect() + + val msgtopic = client.getTopic(topic) + val msgContent = "hello mqtt demo for spark streaming" + val message = new MqttMessage(msgContent.getBytes("utf-8")) + + while (true) { + try { + msgtopic.publish(message) + println(s"Published data. 
topic: ${msgtopic.getName()}; Message: $message") + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + Thread.sleep(10) + println("Queue is full, wait for to consume data from the message queue") + } + } } catch { case e: MqttException => println("Exception Caught: " + e) + } finally { + if (client != null) { + client.disconnect() + } } - - client.connect() - - val msgtopic: MqttTopic = client.getTopic(topic) - val msg: String = "hello mqtt demo for spark streaming" - - while (true) { - val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes("utf-8")) - msgtopic.publish(message) - println("Published data. topic: " + msgtopic.getName() + " Message: " + message) - } - client.disconnect() } } @@ -96,9 +105,9 @@ object MQTTWordCount { val sparkConf = new SparkConf().setAppName("MQTTWordCount") val ssc = new StreamingContext(sparkConf, Seconds(2)) val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) - - val words = lines.flatMap(x => x.toString.split(" ")) + val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() ssc.start() ssc.awaitTermination() diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 4b732c1592ab2..44dec45c227ca 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -19,7 +19,6 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} @@ -121,7 +120,6 @@ object FlumeUtils { * @param
port Port of the host at which the Spark Sink is listening * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, hostname: String, @@ -138,7 +136,6 @@ object FlumeUtils { * @param addresses List of InetSocketAddresses representing the hosts to connect to. * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, addresses: Seq[InetSocketAddress], @@ -159,7 +156,6 @@ object FlumeUtils { * result in this stream using more threads * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, addresses: Seq[InetSocketAddress], @@ -178,7 +174,6 @@ object FlumeUtils { * @param hostname Hostname of the host on which the Spark Sink is running * @param port Port of the host at which the Spark Sink is listening */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, hostname: String, @@ -195,7 +190,6 @@ object FlumeUtils { * @param port Port of the host at which the Spark Sink is listening * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, hostname: String, @@ -212,7 +206,6 @@ object FlumeUtils { * @param addresses List of InetSocketAddresses on which the Spark Sink is running. 
* @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, addresses: Array[InetSocketAddress], @@ -233,7 +226,6 @@ object FlumeUtils { * result in this stream using more threads * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, addresses: Array[InetSocketAddress], diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60c..cfedb5a042a35 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index b57a1c71e35b9..e04d4088df7dc 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -34,10 +34,9 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.util.ManualClock import org.apache.spark.streaming.{Seconds, 
TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { @@ -54,7 +53,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging def beforeFunction() { logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") } before(beforeFunction()) @@ -236,7 +235,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging tx.commit() tx.close() Thread.sleep(500) // Allow some time for the events to reach - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } null } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index f333e3891b5f0..322de7bf2fed8 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.streaming.flume import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer -import java.nio.charset.Charset import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.base.Charsets import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.source.avro @@ -108,7 +108,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val inputEvents = input.map { item => val event = new AvroFlumeEvent - 
event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8"))) + event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) event } @@ -138,14 +138,13 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L status should be (avro.Status.OK) } - val decoder = Charset.forName("UTF-8").newDecoder() eventually(timeout(10 seconds), interval(100 milliseconds)) { val outputEvents = outputBuffer.flatten.map { _.event } outputEvents.foreach { event => event.getHeaders.get("test") should be("header") } - val output = outputEvents.map(event => decoder.decode(event.getBody()).toString) + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) output should be (input) } } diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml new file mode 100644 index 0000000000000..8daa7ed608f6a --- /dev/null +++ b/external/kafka-assembly/pom.xml @@ -0,0 +1,102 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.3.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kafka-assembly_2.10 + jar + Spark Project External Kafka Assembly + http://spark.apache.org/ + + + streaming-kafka-assembly + + + + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index b29b0509656ba..af96138d79405 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -44,7 +44,7 @@ org.apache.kafka 
kafka_${scala.binary.version} - 0.8.0 + 0.8.1.1 com.sun.jmx diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala new file mode 100644 index 0000000000000..5a74febb4bd46 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Represent the host and port info for a Kafka broker. 
+ * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID + */ +@Experimental +final class Broker private( + /** Broker's hostname */ + val host: String, + /** Broker's port */ + val port: Int) extends Serializable { + override def equals(obj: Any): Boolean = obj match { + case that: Broker => + this.host == that.host && + this.port == that.port + case _ => false + } + + override def hashCode: Int = { + 41 * (41 + host.hashCode) + port + } + + override def toString(): String = { + s"Broker($host, $port)" + } +} + +/** + * :: Experimental :: + * Companion object that provides methods to create instances of [[Broker]]. + */ +@Experimental +object Broker { + def create(host: String, port: Int): Broker = + new Broker(host, port) + + def apply(host: String, port: Int): Broker = + new Broker(host, port) + + def unapply(broker: Broker): Option[(String, Int)] = { + if (broker == null) { + None + } else { + Some((broker.host, broker.port)) + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala new file mode 100644 index 0000000000000..04e65cb3d708c --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.reflect.{classTag, ClassTag} + +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.Decoder + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset +import org.apache.spark.streaming.{StreamingContext, Time} +import org.apache.spark.streaming.dstream._ + +/** + * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where + * each given Kafka topic/partition corresponds to an RDD partition. + * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number + * of messages + * per second that each '''partition''' will accept. + * Starting offsets are specified in advance, + * and this DStream is not responsible for committing offsets, + * so that you can control exactly-once semantics. + * For an easy interface to Kafka-managed offsets, + * see {@link org.apache.spark.streaming.kafka.KafkaCluster} + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form. 
+ * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler function for translating each message into the desired type + */ +private[streaming] +class DirectKafkaInputDStream[ + K: ClassTag, + V: ClassTag, + U <: Decoder[K]: ClassTag, + T <: Decoder[V]: ClassTag, + R: ClassTag]( + @transient ssc_ : StreamingContext, + val kafkaParams: Map[String, String], + val fromOffsets: Map[TopicAndPartition, Long], + messageHandler: MessageAndMetadata[K, V] => R +) extends InputDStream[R](ssc_) with Logging { + val maxRetries = context.sparkContext.getConf.getInt( + "spark.streaming.kafka.maxRetries", 1) + + protected[streaming] override val checkpointData = + new DirectKafkaInputDStreamCheckpointData + + protected val kc = new KafkaCluster(kafkaParams) + + protected val maxMessagesPerPartition: Option[Long] = { + val ratePerSec = context.sparkContext.getConf.getInt( + "spark.streaming.kafka.maxRatePerPartition", 0) + if (ratePerSec > 0) { + val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 + Some((secsPerBatch * ratePerSec).toLong) + } else { + None + } + } + + protected var currentOffsets = fromOffsets + + @tailrec + protected final def latestLeaderOffsets(retries: Int): Map[TopicAndPartition, LeaderOffset] = { + val o = kc.getLatestLeaderOffsets(currentOffsets.keySet) + // Either.fold would confuse @tailrec, do it manually + if (o.isLeft) { + val err = o.left.get.toString + if (retries <= 0) { + throw new SparkException(err) + } else { + log.error(err) + Thread.sleep(kc.config.refreshLeaderBackoffMs) + latestLeaderOffsets(retries - 1) + } + } else { + o.right.get + } + } + + // limits the maximum number of messages per partition + protected def clamp( + leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = { + maxMessagesPerPartition.map { mmp => + leaderOffsets.map { case (tp, lo) => + tp -> lo.copy(offset = 
Math.min(currentOffsets(tp) + mmp, lo.offset)) + } + }.getOrElse(leaderOffsets) + } + + override def compute(validTime: Time): Option[KafkaRDD[K, V, U, T, R]] = { + val untilOffsets = clamp(latestLeaderOffsets(maxRetries)) + val rdd = KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) + + currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) + Some(rdd) + } + + override def start(): Unit = { + } + + def stop(): Unit = { + } + + private[streaming] + class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) { + def batchForTime = data.asInstanceOf[mutable.HashMap[ + Time, Array[OffsetRange.OffsetRangeTuple]]] + + override def update(time: Time) { + batchForTime.clear() + generatedRDDs.foreach { kv => + val a = kv._2.asInstanceOf[KafkaRDD[K, V, U, T, R]].offsetRanges.map(_.toTuple).toArray + batchForTime += kv._1 -> a + } + } + + override def cleanup(time: Time) { } + + override def restore() { + // this is assuming that the topics don't change during execution, which is true currently + val topics = fromOffsets.keySet + val leaders = kc.findLeaders(topics).fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + + batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => + logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") + generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) + } + } + } + +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala new file mode 100644 index 0000000000000..2f7e0ab39fefd --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.control.NonFatal +import scala.util.Random +import scala.collection.mutable.ArrayBuffer +import java.util.Properties +import kafka.api._ +import kafka.common.{ErrorMapping, OffsetMetadataAndError, TopicAndPartition} +import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import org.apache.spark.SparkException + +/** + * Convenience methods for interacting with a Kafka cluster. + * @param kafkaParams Kafka + * configuration parameters. 
+ * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form + */ +private[spark] +class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { + import KafkaCluster.{Err, LeaderOffset, SimpleConsumerConfig} + + // ConsumerConfig isn't serializable + @transient private var _config: SimpleConsumerConfig = null + + def config: SimpleConsumerConfig = this.synchronized { + if (_config == null) { + _config = SimpleConsumerConfig(kafkaParams) + } + _config + } + + def connect(host: String, port: Int): SimpleConsumer = + new SimpleConsumer(host, port, config.socketTimeoutMs, + config.socketReceiveBufferBytes, config.clientId) + + def connectLeader(topic: String, partition: Int): Either[Err, SimpleConsumer] = + findLeader(topic, partition).right.map(hp => connect(hp._1, hp._2)) + + // Metadata api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-MetadataAPI + // scalastyle:on + + def findLeader(topic: String, partition: Int): Either[Err, (String, Int)] = { + val req = TopicMetadataRequest(TopicMetadataRequest.CurrentVersion, + 0, config.clientId, Seq(topic)) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp: TopicMetadataResponse = consumer.send(req) + resp.topicsMetadata.find(_.topic == topic).flatMap { tm: TopicMetadata => + tm.partitionsMetadata.find(_.partitionId == partition) + }.foreach { pm: PartitionMetadata => + pm.leader.foreach { leader => + return Right((leader.host, leader.port)) + } + } + } + Left(errs) + } + + def findLeaders( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, (String, Int)]] = { + val topics = topicAndPartitions.map(_.topic) + val response = getPartitionMetadata(topics).right + val answer = response.flatMap { tms: Set[TopicMetadata] => + val leaderMap = 
tms.flatMap { tm: TopicMetadata => + tm.partitionsMetadata.flatMap { pm: PartitionMetadata => + val tp = TopicAndPartition(tm.topic, pm.partitionId) + if (topicAndPartitions(tp)) { + pm.leader.map { l => + tp -> (l.host -> l.port) + } + } else { + None + } + } + }.toMap + + if (leaderMap.keys.size == topicAndPartitions.size) { + Right(leaderMap) + } else { + val missing = topicAndPartitions.diff(leaderMap.keySet) + val err = new Err + err.append(new SparkException(s"Couldn't find leaders for ${missing}")) + Left(err) + } + } + answer + } + + def getPartitions(topics: Set[String]): Either[Err, Set[TopicAndPartition]] = { + getPartitionMetadata(topics).right.map { r => + r.flatMap { tm: TopicMetadata => + tm.partitionsMetadata.map { pm: PartitionMetadata => + TopicAndPartition(tm.topic, pm.partitionId) + } + } + } + } + + def getPartitionMetadata(topics: Set[String]): Either[Err, Set[TopicMetadata]] = { + val req = TopicMetadataRequest( + TopicMetadataRequest.CurrentVersion, 0, config.clientId, topics.toSeq) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp: TopicMetadataResponse = consumer.send(req) + // error codes here indicate missing / just created topic, + // repeating on a different broker wont be useful + return Right(resp.topicsMetadata.toSet) + } + Left(errs) + } + + // Leader offset api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetAPI + // scalastyle:on + + def getLatestLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = + getLeaderOffsets(topicAndPartitions, OffsetRequest.LatestTime) + + def getEarliestLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = + getLeaderOffsets(topicAndPartitions, OffsetRequest.EarliestTime) + + def getLeaderOffsets( + topicAndPartitions: 
Set[TopicAndPartition], + before: Long + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = { + getLeaderOffsets(topicAndPartitions, before, 1).right.map { r => + r.map { kv => + // mapValues isnt serializable, see SI-7005 + kv._1 -> kv._2.head + } + } + } + + private def flip[K, V](m: Map[K, V]): Map[V, Seq[K]] = + m.groupBy(_._2).map { kv => + kv._1 -> kv._2.keys.toSeq + } + + def getLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition], + before: Long, + maxNumOffsets: Int + ): Either[Err, Map[TopicAndPartition, Seq[LeaderOffset]]] = { + findLeaders(topicAndPartitions).right.flatMap { tpToLeader => + val leaderToTp: Map[(String, Int), Seq[TopicAndPartition]] = flip(tpToLeader) + val leaders = leaderToTp.keys + var result = Map[TopicAndPartition, Seq[LeaderOffset]]() + val errs = new Err + withBrokers(leaders, errs) { consumer => + val partitionsToGetOffsets: Seq[TopicAndPartition] = + leaderToTp((consumer.host, consumer.port)) + val reqMap = partitionsToGetOffsets.map { tp: TopicAndPartition => + tp -> PartitionOffsetRequestInfo(before, maxNumOffsets) + }.toMap + val req = OffsetRequest(reqMap) + val resp = consumer.getOffsetsBefore(req) + val respMap = resp.partitionErrorAndOffsets + partitionsToGetOffsets.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { por: PartitionOffsetsResponse => + if (por.error == ErrorMapping.NoError) { + if (por.offsets.nonEmpty) { + result += tp -> por.offsets.map { off => + LeaderOffset(consumer.host, consumer.port, off) + } + } else { + errs.append(new SparkException( + s"Empty offsets for ${tp}, is ${before} before log beginning?")) + } + } else { + errs.append(ErrorMapping.exceptionFor(por.error)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't find leader offsets for ${missing}")) + Left(errs) + } + } + + // Consumer offset api + // scalastyle:off + // 
https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetCommit/FetchAPI + // scalastyle:on + + /** Requires Kafka >= 0.8.1.1 */ + def getConsumerOffsets( + groupId: String, + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, Long]] = { + getConsumerOffsetMetadata(groupId, topicAndPartitions).right.map { r => + r.map { kv => + kv._1 -> kv._2.offset + } + } + } + + /** Requires Kafka >= 0.8.1.1 */ + def getConsumerOffsetMetadata( + groupId: String, + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = { + var result = Map[TopicAndPartition, OffsetMetadataAndError]() + val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp = consumer.fetchOffsets(req) + val respMap = resp.requestInfo + val needed = topicAndPartitions.diff(result.keySet) + needed.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { ome: OffsetMetadataAndError => + if (ome.error == ErrorMapping.NoError) { + result += tp -> ome + } else { + errs.append(ErrorMapping.exceptionFor(ome.error)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't find consumer offsets for ${missing}")) + Left(errs) + } + + /** Requires Kafka >= 0.8.1.1 */ + def setConsumerOffsets( + groupId: String, + offsets: Map[TopicAndPartition, Long] + ): Either[Err, Map[TopicAndPartition, Short]] = { + setConsumerOffsetMetadata(groupId, offsets.map { kv => + kv._1 -> OffsetMetadataAndError(kv._2) + }) + } + + /** Requires Kafka >= 0.8.1.1 */ + def setConsumerOffsetMetadata( + groupId: String, + metadata: Map[TopicAndPartition, OffsetMetadataAndError] + ): Either[Err, Map[TopicAndPartition, Short]] = { + var result = 
Map[TopicAndPartition, Short]() + val req = OffsetCommitRequest(groupId, metadata) + val errs = new Err + val topicAndPartitions = metadata.keySet + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp = consumer.commitOffsets(req) + val respMap = resp.requestInfo + val needed = topicAndPartitions.diff(result.keySet) + needed.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { err: Short => + if (err == ErrorMapping.NoError) { + result += tp -> err + } else { + errs.append(ErrorMapping.exceptionFor(err)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't set offsets for ${missing}")) + Left(errs) + } + + // Try a call against potentially multiple brokers, accumulating errors + private def withBrokers(brokers: Iterable[(String, Int)], errs: Err) + (fn: SimpleConsumer => Any): Unit = { + brokers.foreach { hp => + var consumer: SimpleConsumer = null + try { + consumer = connect(hp._1, hp._2) + fn(consumer) + } catch { + case NonFatal(e) => + errs.append(e) + } finally { + if (consumer != null) { + consumer.close() + } + } + } + } +} + +private[spark] +object KafkaCluster { + type Err = ArrayBuffer[Throwable] + + private[spark] + case class LeaderOffset(host: String, port: Int, offset: Long) + + /** + * High-level kafka consumers connect to ZK. ConsumerConfig assumes this use case. + * Simple consumers connect directly to brokers, but need many of the same configs. + * This subclass won't warn about missing ZK params, or presence of broker params. 
+ */ + private[spark] + class SimpleConsumerConfig private(brokers: String, originalProps: Properties) + extends ConsumerConfig(originalProps) { + val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp => + val hpa = hp.split(":") + if (hpa.size == 1) { // port component is missing + throw new SparkException(s"Broker not in the correct format of <host>:<port> [$brokers]") + } + (hpa(0), hpa(1).toInt) + } + } + + private[spark] + object SimpleConsumerConfig { + /** + * Make a consumer config without requiring group.id or zookeeper.connect, + * since communicating with brokers also needs common settings such as timeout + */ + def apply(kafkaParams: Map[String, String]): SimpleConsumerConfig = { + // These keys are from other pre-existing kafka configs for specifying brokers, accept either + val brokers = kafkaParams.get("metadata.broker.list") + .orElse(kafkaParams.get("bootstrap.servers")) + .getOrElse(throw new SparkException( + "Must specify metadata.broker.list or bootstrap.servers")) + + val props = new Properties() + kafkaParams.foreach { case (key, value) => + // prevent warnings on parameters ConsumerConfig doesn't know about + if (key != "metadata.broker.list" && key != "bootstrap.servers") { + props.put(key, value) + } + } + + Seq("zookeeper.connect", "group.id").foreach { s => + if (!props.containsKey(s)) { // Hashtable.contains would test values, not keys + props.setProperty(s, "") + } + } + + new SimpleConsumerConfig(brokers, props) + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala new file mode 100644 index 0000000000000..d56cc01be9514 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.reflect.{classTag, ClassTag} + +import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.NextIterator + +import java.util.Properties +import kafka.api.{FetchRequestBuilder, FetchResponse} +import kafka.common.{ErrorMapping, TopicAndPartition} +import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import kafka.message.{MessageAndMetadata, MessageAndOffset} +import kafka.serializer.Decoder +import kafka.utils.VerifiableProperties + +/** + * A batch-oriented interface for consuming from Kafka. + * Starting and ending offsets are specified in advance, + * so that you can control exactly-once semantics. + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. 
+ * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD + * @param messageHandler function for translating each message into the desired type + */ +private[kafka] +class KafkaRDD[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag, + R: ClassTag] private[spark] ( + sc: SparkContext, + kafkaParams: Map[String, String], + val offsetRanges: Array[OffsetRange], + leaders: Map[TopicAndPartition, (String, Int)], + messageHandler: MessageAndMetadata[K, V] => R + ) extends RDD[R](sc, Nil) with Logging with HasOffsetRanges { + override def getPartitions: Array[Partition] = { + offsetRanges.zipWithIndex.map { case (o, i) => + val (host, port) = leaders(TopicAndPartition(o.topic, o.partition)) + new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset, host, port) + }.toArray + } + + override def getPreferredLocations(thePart: Partition): Seq[String] = { + val part = thePart.asInstanceOf[KafkaRDDPartition] + // TODO is additional hostname resolution necessary here + Seq(part.host) + } + + private def errBeginAfterEnd(part: KafkaRDDPartition): String = + s"Beginning offset ${part.fromOffset} is after the ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition}. " + + "You either provided an invalid fromOffset, or the Kafka topic has been damaged" + + private def errRanOutBeforeEnd(part: KafkaRDDPartition): String = + s"Ran out of messages before reaching ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition} start ${part.fromOffset}." + + " This should not happen, and indicates that messages may have been lost" + + private def errOvershotEnd(itemOffset: Long, part: KafkaRDDPartition): String = + s"Got ${itemOffset} > ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition} start ${part.fromOffset}." 
+ " This should not happen, and indicates a message may have been skipped" + + override def compute(thePart: Partition, context: TaskContext): Iterator[R] = { + val part = thePart.asInstanceOf[KafkaRDDPartition] + assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + if (part.fromOffset == part.untilOffset) { // empty offset range, nothing to fetch + log.warn(s"Beginning offset ${part.fromOffset} is the same as ending offset " + + s"skipping ${part.topic} ${part.partition}") + Iterator.empty + } else { + new KafkaRDDIterator(part, context) + } + } + + private class KafkaRDDIterator( + part: KafkaRDDPartition, + context: TaskContext) extends NextIterator[R] { + + context.addTaskCompletionListener{ context => closeIfNeeded() } + + log.info(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + + val kc = new KafkaCluster(kafkaParams) + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(kc.config.props) + .asInstanceOf[Decoder[K]] + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(kc.config.props) + .asInstanceOf[Decoder[V]] + val consumer = connectLeader + var requestOffset = part.fromOffset + var iter: Iterator[MessageAndOffset] = null + + // The idea is to use the provided preferred host, except on task retry attempts, + // to minimize number of kafka metadata requests + private def connectLeader: SimpleConsumer = { + if (context.attemptNumber > 0) { + kc.connectLeader(part.topic, part.partition).fold( + errs => throw new SparkException( + s"Couldn't connect to leader for topic ${part.topic} ${part.partition}: " + + errs.mkString("\n")), + consumer => consumer + ) + } else { + kc.connect(part.host, part.port) + } + } + + private def handleFetchErr(resp: FetchResponse) { + if (resp.hasError) { + val err = resp.errorCode(part.topic, part.partition) + if (err == ErrorMapping.LeaderNotAvailableCode || + err == 
ErrorMapping.NotLeaderForPartitionCode) { + log.error(s"Lost leader for topic ${part.topic} partition ${part.partition}, " + + s" sleeping for ${kc.config.refreshLeaderBackoffMs}ms") + Thread.sleep(kc.config.refreshLeaderBackoffMs) + } + // Let normal rdd retry sort out reconnect attempts + throw ErrorMapping.exceptionFor(err) + } + } + + private def fetchBatch: Iterator[MessageAndOffset] = { + val req = new FetchRequestBuilder() + .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) + .build() + val resp = consumer.fetch(req) + handleFetchErr(resp) + // kafka may return a batch that starts before the requested offset + resp.messageSet(part.topic, part.partition) + .iterator + .dropWhile(_.offset < requestOffset) + } + + override def close() = consumer.close() + + override def getNext(): R = { + if (iter == null || !iter.hasNext) { + iter = fetchBatch + } + if (!iter.hasNext) { + assert(requestOffset == part.untilOffset, errRanOutBeforeEnd(part)) + finished = true + null.asInstanceOf[R] + } else { + val item = iter.next() + if (item.offset >= part.untilOffset) { + assert(item.offset == part.untilOffset, errOvershotEnd(item.offset, part)) + finished = true + null.asInstanceOf[R] + } else { + requestOffset = item.nextOffset + messageHandler(new MessageAndMetadata( + part.topic, part.partition, item.message, item.offset, keyDecoder, valueDecoder)) + } + } + } + } +} + +private[kafka] +object KafkaRDD { + import KafkaCluster.LeaderOffset + + /** + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form. 
+ * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the batch + * @param untilOffsets per-topic/partition Kafka offsets defining the (exclusive) + * ending point of the batch + * @param messageHandler function for translating each message into the desired type + */ + def apply[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag, + R: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + fromOffsets: Map[TopicAndPartition, Long], + untilOffsets: Map[TopicAndPartition, LeaderOffset], + messageHandler: MessageAndMetadata[K, V] => R + ): KafkaRDD[K, V, U, T, R] = { + val leaders = untilOffsets.map { case (tp, lo) => + tp -> (lo.host, lo.port) + }.toMap + + val offsetRanges = fromOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + }.toArray + + new KafkaRDD[K, V, U, T, R](sc, kafkaParams, offsetRanges, leaders, messageHandler) + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala new file mode 100644 index 0000000000000..a842a6f17766f --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import org.apache.spark.Partition + +/** @param topic kafka topic name + * @param partition kafka partition id + * @param fromOffset inclusive starting offset + * @param untilOffset exclusive ending offset + * @param host preferred kafka host, i.e. the leader at the time the rdd was created + * @param port preferred kafka host's port + */ +private[kafka] +class KafkaRDDPartition( + val index: Int, + val topic: String, + val partition: Int, + val fromOffset: Long, + val untilOffset: Long, + val host: String, + val port: Int +) extends Partition diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index df725f0c65a64..af04bc6576148 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -18,21 +18,30 @@ package org.apache.spark.streaming.kafka import java.lang.{Integer => JInt} +import java.lang.{Long => JLong} import java.util.{Map => JMap} +import java.util.{Set => JSet} import scala.reflect.ClassTag import scala.collection.JavaConversions._ +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata import kafka.serializer.{Decoder, StringDecoder} +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.annotation.Experimental +import 
org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.api.java.{JavaPairInputDStream, JavaInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} object KafkaUtils { /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param ssc StreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..) * @param groupId The group id for this consumer @@ -56,7 +65,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param ssc StreamingContext object * @param kafkaParams Map of kafka configuration parameters, * see http://kafka.apache.org/08/configuration.html @@ -75,7 +84,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param jssc JavaStreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..) @@ -93,7 +102,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param jssc JavaStreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. 
@@ -113,10 +122,10 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param jssc JavaStreamingContext object - * @param keyTypeClass Key type of RDD - * @param valueTypeClass value type of RDD + * @param keyTypeClass Key type of DStream + * @param valueTypeClass value type of Dstream * @param keyDecoderClass Type of kafka key decoder * @param valueDecoderClass Type of kafka value decoder * @param kafkaParams Map of kafka configuration parameters, @@ -144,4 +153,382 @@ object KafkaUtils { createStream[K, V, U, T]( jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } + + /** get leaders for the given offset ranges, or throw an exception */ + private def leadersForRanges( + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = { + val kc = new KafkaCluster(kafkaParams) + val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet + val leaders = kc.findLeaders(topics).fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + leaders + } + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. + * + * @param sc SparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. 
+ * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + */ + @Experimental + def createRDD[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange] + ): RDD[(K, V)] = { + val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) + val leaders = leadersForRanges(kafkaParams, offsetRanges) + new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) + } + + /** + * :: Experimental :: + * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * specify the Kafka leader to connect to (to optimize fetching) and access the message as well + * as the metadata. + * + * @param sc SparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, + * in which case leaders will be looked up on the driver. 
+ * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createRDD[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag, + R: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange], + leaders: Map[TopicAndPartition, Broker], + messageHandler: MessageAndMetadata[K, V] => R + ): RDD[R] = { + val leaderMap = if (leaders.isEmpty) { + leadersForRanges(kafkaParams, offsetRanges) + } else { + // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker + leaders.map { + case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) + }.toMap + } + new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) + } + + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. + * + * @param jsc JavaSparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. 
+ * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + */ + @Experimental + def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( + jsc: JavaSparkContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + kafkaParams: JMap[String, String], + offsetRanges: Array[OffsetRange] + ): JavaPairRDD[K, V] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + new JavaPairRDD(createRDD[K, V, KD, VD]( + jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges)) + } + + /** + * :: Experimental :: + * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * specify the Kafka leader to connect to (to optimize fetching) and access the message as well + * as the metadata. + * + * @param jsc JavaSparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, + * in which case leaders will be looked up on the driver. 
+ * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( + jsc: JavaSparkContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + recordClass: Class[R], + kafkaParams: JMap[String, String], + offsetRanges: Array[OffsetRange], + leaders: JMap[TopicAndPartition, Broker], + messageHandler: JFunction[MessageAndMetadata[K, V], R] + ): JavaRDD[R] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + val leaderMap = Map(leaders.toSeq: _*) + createRDD[K, V, KD, VD, R]( + jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges, leaderMap, messageHandler.call _) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. 
The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param ssc StreamingContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createDirectStream[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag, + R: ClassTag] ( + ssc: StreamingContext, + kafkaParams: Map[String, String], + fromOffsets: Map[TopicAndPartition, Long], + messageHandler: MessageAndMetadata[K, V] => R + ): InputDStream[R] = { + new DirectKafkaInputDStream[K, V, KD, VD, R]( + ssc, kafkaParams, fromOffsets, messageHandler) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. 
For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param ssc StreamingContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. 
+ * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" + * to determine where the stream starts (defaults to "largest") + * @param topics Names of the topics to consume + */ + @Experimental + def createDirectStream[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag] ( + ssc: StreamingContext, + kafkaParams: Map[String, String], + topics: Set[String] + ): InputDStream[(K, V)] = { + val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) + val kc = new KafkaCluster(kafkaParams) + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + + (for { + topicPartitions <- kc.getPartitions(topics).right + leaderOffsets <- (if (reset == Some("smallest")) { + kc.getEarliestLeaderOffsets(topicPartitions) + } else { + kc.getLatestLeaderOffsets(topicPartitions) + }).right + } yield { + val fromOffsets = leaderOffsets.map { case (tp, lo) => + (tp, lo.offset) + } + new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( + ssc, kafkaParams, fromOffsets, messageHandler) + }).fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). 
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param jssc JavaStreamingContext object + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param keyDecoderClass Class of the key decoder + * @param valueDecoderClass Class of the value decoder + * @param recordClass Class of the records in DStream + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. 
+ * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( + jssc: JavaStreamingContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + recordClass: Class[R], + kafkaParams: JMap[String, String], + fromOffsets: JMap[TopicAndPartition, JLong], + messageHandler: JFunction[MessageAndMetadata[K, V], R] + ): JavaInputDStream[R] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + createDirectStream[K, V, KD, VD, R]( + jssc.ssc, + Map(kafkaParams.toSeq: _*), + Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), + messageHandler.call _ + ) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). 
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param jssc JavaStreamingContext object + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param keyDecoderClass Class of the key decoder + * @param valueDecoderClass Class type of the value decoder + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. 
+ * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" + * to determine where the stream starts (defaults to "largest") + * @param topics Names of the topics to consume + */ + @Experimental + def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( + jssc: JavaStreamingContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + kafkaParams: JMap[String, String], + topics: JSet[String] + ): JavaPairInputDStream[K, V] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + createDirectStream[K, V, KD, VD]( + jssc.ssc, + Map(kafkaParams.toSeq: _*), + Set(topics.toSeq: _*) + ) + } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala new file mode 100644 index 0000000000000..9c3dfeb8f5928 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka
+
+import kafka.common.TopicAndPartition
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the
+ * offset ranges in RDDs generated by the direct Kafka DStream (see
+ * [[KafkaUtils.createDirectStream()]]).
+ * {{{
+ *   KafkaUtils.createDirectStream(...).foreachRDD { rdd =>
+ *      val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ *      ...
+ *   }
+ * }}}
+ */
+@Experimental
+trait HasOffsetRanges {
+  def offsetRanges: Array[OffsetRange]
+}
+
+/**
+ * :: Experimental ::
+ * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class
+ * can be created with `OffsetRange.create()`.
+ */
+@Experimental
+final class OffsetRange private(
+    /** Kafka topic name */
+    val topic: String,
+    /** Kafka partition id */
+    val partition: Int,
+    /** inclusive starting offset */
+    val fromOffset: Long,
+    /** exclusive ending offset */
+    val untilOffset: Long) extends Serializable {
+  import OffsetRange.OffsetRangeTuple
+
+  override def equals(obj: Any): Boolean = obj match {
+    case that: OffsetRange =>
+      this.topic == that.topic &&
+        this.partition == that.partition &&
+        this.fromOffset == that.fromOffset &&
+        this.untilOffset == that.untilOffset
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    toTuple.hashCode()
+  }
+
+  override def toString(): String = {
+    s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])"
+  }
+
+  /** this is to avoid ClassNotFoundException during checkpoint restore */
+  private[streaming]
+  def toTuple: OffsetRangeTuple = (topic, partition, fromOffset, untilOffset)
+}
+
+/**
+ * :: Experimental ::
+ * Companion object that provides methods to create instances of [[OffsetRange]].
+ */ +@Experimental +object OffsetRange { + def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def create( + topicAndPartition: TopicAndPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset) + + def apply(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def apply( + topicAndPartition: TopicAndPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset) + + /** this is to avoid ClassNotFoundException during checkpoint restore */ + private[kafka] + type OffsetRangeTuple = (String, Int, Long, Long) + + private[kafka] + def apply(t: OffsetRangeTuple) = + new OffsetRange(t._1, t._2, t._3, t._4) +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index be734b80272d1..c4a44c1822c39 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -201,12 +201,31 @@ class ReliableKafkaReceiver[ topicPartitionOffsetMap.clear() } - /** Store the ready-to-be-stored block and commit the related offsets to zookeeper. */ + /** + * Store the ready-to-be-stored block and commit the related offsets to zookeeper. This method + * will try a fixed number of times to push the block. If the push fails, the receiver is stopped. 
+ */ private def storeBlockAndCommitOffset( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { - store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) - Option(blockOffsetMap.get(blockId)).foreach(commitOffset) - blockOffsetMap.remove(blockId) + var count = 0 + var pushed = false + var exception: Exception = null + while (!pushed && count <= 3) { + try { + store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) + pushed = true + } catch { + case ex: Exception => + count += 1 + exception = ex + } + } + if (pushed) { + Option(blockOffsetMap.get(blockId)).foreach(commitOffset) + blockOffsetMap.remove(blockId) + } else { + stop("Error while storing block into Spark", exception) + } } /** diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java new file mode 100644 index 0000000000000..1334cc8fd1b57 --- /dev/null +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kafka; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Random; +import java.util.Arrays; + +import org.apache.spark.SparkConf; + +import scala.Tuple2; + +import junit.framework.Assert; + +import kafka.common.TopicAndPartition; +import kafka.message.MessageAndMetadata; +import kafka.serializer.StringDecoder; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +import org.junit.Test; +import org.junit.After; +import org.junit.Before; + +public class JavaDirectKafkaStreamSuite implements Serializable { + private transient JavaStreamingContext ssc = null; + private transient Random random = new Random(); + private transient KafkaStreamSuiteBase suiteBase = null; + + @Before + public void setUp() { + suiteBase = new KafkaStreamSuiteBase() { }; + suiteBase.setupKafka(); + System.clearProperty("spark.driver.port"); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + System.clearProperty("spark.driver.port"); + suiteBase.tearDownKafka(); + } + + @Test + public void testKafkaStream() throws InterruptedException { + String topic1 = "topic1"; + String topic2 = "topic2"; + + String[] topic1data = createTopicAndSendData(topic1); + String[] topic2data = createTopicAndSendData(topic2); + + HashSet sent = new HashSet(); + sent.addAll(Arrays.asList(topic1data)); + sent.addAll(Arrays.asList(topic2data)); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + kafkaParams.put("auto.offset.reset", "smallest"); + + 
JavaDStream stream1 = KafkaUtils.createDirectStream( + ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topicToSet(topic1) + ).map( + new Function, String>() { + @Override + public String call(scala.Tuple2 kv) throws Exception { + return kv._2(); + } + } + ); + + JavaDStream stream2 = KafkaUtils.createDirectStream( + ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + topicOffsetToMap(topic2, (long) 0), + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + JavaDStream unifiedStream = stream1.union(stream2); + + final HashSet result = new HashSet(); + unifiedStream.foreachRDD( + new Function, Void>() { + @Override + public Void call(org.apache.spark.api.java.JavaRDD rdd) throws Exception { + result.addAll(rdd.collect()); + return null; + } + } + ); + ssc.start(); + long startTime = System.currentTimeMillis(); + boolean matches = false; + while (!matches && System.currentTimeMillis() - startTime < 20000) { + matches = sent.size() == result.size(); + Thread.sleep(50); + } + Assert.assertEquals(sent, result); + ssc.stop(); + } + + private HashSet topicToSet(String topic) { + HashSet topicSet = new HashSet(); + topicSet.add(topic); + return topicSet; + } + + private HashMap topicOffsetToMap(String topic, Long offsetToStart) { + HashMap topicMap = new HashMap(); + topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); + return topicMap; + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + suiteBase.createTopic(topic); + suiteBase.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java new file mode 100644 
index 0000000000000..9d2e1705c6c73 --- /dev/null +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Arrays; + +import org.apache.spark.SparkConf; + +import scala.Tuple2; + +import junit.framework.Assert; + +import kafka.common.TopicAndPartition; +import kafka.message.MessageAndMetadata; +import kafka.serializer.StringDecoder; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import org.junit.Test; +import org.junit.After; +import org.junit.Before; + +public class JavaKafkaRDDSuite implements Serializable { + private transient JavaSparkContext sc = null; + private transient KafkaStreamSuiteBase suiteBase = null; + + @Before + public void setUp() { + suiteBase = new KafkaStreamSuiteBase() { }; + suiteBase.setupKafka(); + System.clearProperty("spark.driver.port"); + SparkConf sparkConf = new SparkConf() + 
.setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + sc = new JavaSparkContext(sparkConf); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + suiteBase.tearDownKafka(); + } + + @Test + public void testKafkaRDD() throws InterruptedException { + String topic1 = "topic1"; + String topic2 = "topic2"; + + String[] topic1data = createTopicAndSendData(topic1); + String[] topic2data = createTopicAndSendData(topic2); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + + OffsetRange[] offsetRanges = { + OffsetRange.create(topic1, 0, 0, 1), + OffsetRange.create(topic2, 0, 0, 1) + }; + + HashMap emptyLeaders = new HashMap(); + HashMap leaders = new HashMap(); + String[] hostAndPort = suiteBase.brokerAddress().split(":"); + Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); + leaders.put(new TopicAndPartition(topic1, 0), broker); + leaders.put(new TopicAndPartition(topic2, 0), broker); + + JavaRDD rdd1 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + offsetRanges + ).map( + new Function, String>() { + @Override + public String call(scala.Tuple2 kv) throws Exception { + return kv._2(); + } + } + ); + + JavaRDD rdd2 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + offsetRanges, + emptyLeaders, + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + + JavaRDD rdd3 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + offsetRanges, + leaders, + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return 
msgAndMd.message(); + } + } + ); + + // just making sure the java user apis work; the scala tests handle logic corner cases + long count1 = rdd1.count(); + long count2 = rdd2.count(); + long count3 = rdd3.count(); + Assert.assertTrue(count1 > 0); + Assert.assertEquals(count1, count2); + Assert.assertEquals(count1, count3); + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + suiteBase.createTopic(topic); + suiteBase.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 6e1abf3f385ee..208cc51b29876 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -79,9 +79,10 @@ public void testKafkaStream() throws InterruptedException { suiteBase.createTopic(topic); HashMap<String, Object> tmp = new HashMap<String, Object>(sent); - suiteBase.produceAndSendMessage(topic, + suiteBase.sendMessages(topic, JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.<Tuple2<String, Object>>conforms())); + Predef.<Tuple2<String, Object>>conforms()) + ); HashMap<String, String> kafkaParams = new HashMap<String, String>(); kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala new file mode 100644 index 0000000000000..17ca9d145d665 --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.io.File + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.postfixOps + +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.StringDecoder +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils + +class DirectKafkaStreamSuite extends KafkaStreamSuiteBase + with BeforeAndAfter with BeforeAndAfterAll with Eventually { + val sparkConf = new SparkConf() + .setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + + var sc: SparkContext = _ + var ssc: StreamingContext = _ + var testDir: File = _ + + override def beforeAll { + setupKafka() + } + + override def afterAll { + tearDownKafka() + } + + after { + if (ssc != null) { + ssc.stop() + sc = null + } + if (sc != null) { + sc.stop() + } + if (testDir != null) { + Utils.deleteRecursively(testDir) + } + } + + + test("basic stream receiving with multiple topics and smallest starting offset") { + val topics = Set("basic1", "basic2", "basic3") + val 
data = Map("a" -> 7, "b" -> 9) + topics.foreach { t => + createTopic(t) + sendMessages(t, data) + } + val totalSent = data.values.sum * topics.size + val kafkaParams = Map( + "metadata.broker.list" -> s"$brokerAddress", + "auto.offset.reset" -> "smallest" + ) + + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topics) + } + + val allReceived = new ArrayBuffer[(String, String)] + + stream.foreachRDD { rdd => + // Get the offset ranges in the RDD + val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + val collected = rdd.mapPartitionsWithIndex { (i, iter) => + // For each partition, get size of the range in the partition, + // and the number of items in the partition + val off = offsets(i) + val all = iter.toSeq + val partSize = all.size + val rangeSize = off.untilOffset - off.fromOffset + Iterator((partSize, rangeSize)) + }.collect + + // Verify whether number of elements in each partition + // matches with the corresponding offset range + collected.foreach { case (partSize, rangeSize) => + assert(partSize === rangeSize, "offset ranges are wrong") + } + } + stream.foreachRDD { rdd => allReceived ++= rdd.collect() } + ssc.start() + eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + assert(allReceived.size === totalSent, + "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n")) + } + ssc.stop() + } + + test("receiving from largest starting offset") { + val topic = "largest" + val topicPartition = TopicAndPartition(topic, 0) + val data = Map("a" -> 10) + createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> s"$brokerAddress", + "auto.offset.reset" -> "largest" + ) + val kc = new KafkaCluster(kafkaParams) + def getLatestOffset(): Long = { + kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset + } + + // 
Send some initial messages before starting context + sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() > 3) + } + val offsetBeforeStart = getLatestOffset() + + // Setup context and kafka stream with largest offset + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Set(topic)) + } + assert( + stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]] + .fromOffsets(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + + val collectedData = new mutable.ArrayBuffer[String]() + stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } + ssc.start() + val newData = Map("b" -> 10) + sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(collectedData.contains("b")) + } + assert(!collectedData.contains("a")) + } + + + test("creating stream by offset") { + val topic = "offset" + val topicPartition = TopicAndPartition(topic, 0) + val data = Map("a" -> 10) + createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> s"$brokerAddress", + "auto.offset.reset" -> "largest" + ) + val kc = new KafkaCluster(kafkaParams) + def getLatestOffset(): Long = { + kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset + } + + // Send some initial messages before starting context + sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() >= 10) + } + val offsetBeforeStart = getLatestOffset() + + // Setup context and kafka stream with largest offset + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder, String]( + ssc, kafkaParams, Map(topicPartition ->
11L), + (m: MessageAndMetadata[String, String]) => m.message()) + } + assert( + stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]] + .fromOffsets(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + + val collectedData = new mutable.ArrayBuffer[String]() + stream.foreachRDD { rdd => collectedData ++= rdd.collect() } + ssc.start() + val newData = Map("b" -> 10) + sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(collectedData.contains("b")) + } + assert(!collectedData.contains("a")) + } + + // Test to verify the offset ranges can be recovered from the checkpoints + test("offset recovery") { + val topic = "recovery" + createTopic(topic) + testDir = Utils.createTempDir() + + val kafkaParams = Map( + "metadata.broker.list" -> s"$brokerAddress", + "auto.offset.reset" -> "smallest" + ) + + // Send data to Kafka and wait for it to be received + def sendDataAndWaitForReceive(data: Seq[Int]) { + val strings = data.map { _.toString} + sendMessages(topic, strings.map { _ -> 1}.toMap) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains }) + } + } + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(100)) + val kafkaStream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Set(topic)) + } + val keyedStream = kafkaStream.map { v => "key" -> v._2.toInt } + val stateStream = keyedStream.updateStateByKey { (values: Seq[Int], state: Option[Int]) => + Some(values.sum + state.getOrElse(0)) + } + ssc.checkpoint(testDir.getAbsolutePath) + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + DirectKafkaStreamSuite.collectedData.appendAll(data) + } + + // This is to ensure all the data is
eventually received only once + stateStream.foreachRDD { (rdd: RDD[(String, Int)]) => + rdd.collect().headOption.foreach { x => DirectKafkaStreamSuite.total = x._2 } + } + ssc.start() + + // Send some data and wait for them to be received + for (i <- (1 to 10).grouped(4)) { + sendDataAndWaitForReceive(i) + } + + // Verify that offset ranges were generated + val offsetRangesBeforeStop = getOffsetRanges(kafkaStream) + assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated") + assert( + offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 }, + "starting offset not zero" + ) + ssc.stop() + logInfo("====== RESTARTING ========") + + // Recover context from checkpoints + ssc = new StreamingContext(testDir.getAbsolutePath) + val recoveredStream = ssc.graph.getInputStreams().head.asInstanceOf[DStream[(String, String)]] + + // Verify offset ranges have been recovered + val recoveredOffsetRanges = getOffsetRanges(recoveredStream) + assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered") + val earlierOffsetRangesAsSets = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) } + assert( + recoveredOffsetRanges.forall { or => + earlierOffsetRangesAsSets.contains((or._1, or._2.toSet)) + }, + "Recovered ranges are not the same as the ones generated" + ) + + // Restart context, give more data and verify the total at the end + // If the total is right that means each record has been received only once + ssc.start() + sendDataAndWaitForReceive(11 to 20) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(DirectKafkaStreamSuite.total === (1 to 20).sum) + } + ssc.stop() + } + + /** Get the generated offset ranges from the DirectKafkaStream */ + private def getOffsetRanges[K, V]( + kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { + kafkaStream.generatedRDDs.mapValues { rdd => + rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges + }.toSeq.sortBy { _._1 } + } +} + +object DirectKafkaStreamSuite { + val
collectedData = new mutable.ArrayBuffer[String]() + var total = -1L +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala new file mode 100644 index 0000000000000..fc9275b7207be --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kafka + +import scala.util.Random + +import kafka.common.TopicAndPartition +import org.scalatest.BeforeAndAfterAll + +class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { + val topic = "kcsuitetopic" + Random.nextInt(10000) + val topicAndPartition = TopicAndPartition(topic, 0) + var kc: KafkaCluster = null + + override def beforeAll() { + setupKafka() + createTopic(topic) + sendMessages(topic, Map("a" -> 1)) + kc = new KafkaCluster(Map("metadata.broker.list" -> s"$brokerAddress")) + } + + override def afterAll() { + tearDownKafka() + } + + test("metadata apis") { + val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition) + val leaderAddress = s"${leader._1}:${leader._2}" + assert(leaderAddress === brokerAddress, "didn't get leader") + + val parts = kc.getPartitions(Set(topic)).right.get + assert(parts(topicAndPartition), "didn't get partitions") + } + + test("leader offset apis") { + val earliest = kc.getEarliestLeaderOffsets(Set(topicAndPartition)).right.get + assert(earliest(topicAndPartition).offset === 0, "didn't get earliest") + + val latest = kc.getLatestLeaderOffsets(Set(topicAndPartition)).right.get + assert(latest(topicAndPartition).offset === 1, "didn't get latest") + } + + test("consumer offset apis") { + val group = "kcsuitegroup" + Random.nextInt(10000) + + val offset = Random.nextInt(10000) + + val set = kc.setConsumerOffsets(group, Map(topicAndPartition -> offset)) + assert(set.isRight, "didn't set consumer offsets") + + val get = kc.getConsumerOffsets(group, Set(topicAndPartition)).right.get + assert(get(topicAndPartition) === offset, "didn't get consumer offsets") + } +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala new file mode 100644 index 0000000000000..a223da70b043f --- /dev/null +++ 
b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.Random + +import kafka.serializer.StringDecoder +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { + val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) + var sc: SparkContext = _ + override def beforeAll { + sc = new SparkContext(sparkConf) + + setupKafka() + } + + override def afterAll { + if (sc != null) { + sc.stop + sc = null + } + tearDownKafka() + } + + test("basic usage") { + val topic = "topicbasic" + createTopic(topic) + val messages = Set("the", "quick", "brown", "fox") + sendMessages(topic, messages.toArray) + + + val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = 
KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, offsetRanges) + + val received = rdd.map(_._2).collect.toSet + assert(received === messages) + } + + test("iterator boundary conditions") { + // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + + val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + + val kc = new KafkaCluster(kafkaParams) + + // this is the "lots of messages" case + sendMessages(topic, sent) + // rdd defined from leaders after sending messages, should get the number sent + val rdd = getRdd(kc, Set(topic)) + + assert(rdd.isDefined) + assert(rdd.get.count === sent.values.sum, "didn't get all sent messages") + + val ranges = rdd.get.asInstanceOf[HasOffsetRanges] + .offsetRanges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap + + kc.setConsumerOffsets(kafkaParams("group.id"), ranges) + + // this is the "0 messages" case + val rdd2 = getRdd(kc, Set(topic)) + // shouldn't get anything, since message is sent after rdd was defined + val sentOnlyOne = Map("d" -> 1) + + sendMessages(topic, sentOnlyOne) + assert(rdd2.isDefined) + assert(rdd2.get.count === 0, "got messages when there shouldn't be any") + + // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above + val rdd3 = getRdd(kc, Set(topic)) + // send lots of messages after rdd was defined, they shouldn't show up + sendMessages(topic, Map("extra" -> 22)) + + assert(rdd3.isDefined) + assert(rdd3.get.count === sentOnlyOne.values.sum, "didn't get exactly one message") + + } + + // get an rdd from the committed consumer offsets until the latest leader offsets, + private def getRdd(kc: KafkaCluster, topics: Set[String]) = { + val groupId = kc.kafkaParams("group.id") + def consumerOffsets(topicPartitions: 
Set[TopicAndPartition]) = { + kc.getConsumerOffsets(groupId, topicPartitions).right.toOption.orElse( + kc.getEarliestLeaderOffsets(topicPartitions).right.toOption.map { offs => + offs.map(kv => kv._1 -> kv._2.offset) + } + ) + } + kc.getPartitions(topics).right.toOption.flatMap { topicPartitions => + consumerOffsets(topicPartitions).flatMap { from => + kc.getLatestLeaderOffsets(topicPartitions).right.toOption.map { until => + val offsetRanges = from.map { case (tp: TopicAndPartition, fromOffset: Long) => + OffsetRange(tp.topic, tp.partition, fromOffset, until(tp).offset) + }.toArray + + val leaders = until.map { case (tp: TopicAndPartition, lo: KafkaCluster.LeaderOffset) => + tp -> Broker(lo.host, lo.port) + }.toMap + + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder, String]( + sc, kc.kafkaParams, offsetRanges, leaders, + (mmd: MessageAndMetadata[String, String]) => s"${mmd.offset} ${mmd.message}") + } + } + } + } +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index b19c053ebfc44..e4966eebb9b34 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -26,7 +26,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import kafka.admin.CreateTopicCommand +import kafka.admin.AdminUtils import kafka.common.{KafkaException, TopicAndPartition} import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.{StringDecoder, StringEncoder} @@ -48,30 +48,41 @@ import org.apache.spark.util.Utils */ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { - var zkAddress: String = _ - var zkClient: ZkClient = _ - private val zkHost = "localhost" + private var zkPort: Int = 0 private val 
zkConnectionTimeout = 6000 private val zkSessionTimeout = 6000 private var zookeeper: EmbeddedZookeeper = _ - private var zkPort: Int = 0 + private val brokerHost = "localhost" private var brokerPort = 9092 private var brokerConf: KafkaConfig = _ private var server: KafkaServer = _ private var producer: Producer[String, String] = _ + private var zkReady = false + private var brokerReady = false + + protected var zkClient: ZkClient = _ + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } def setupKafka() { // Zookeeper server startup zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") // Get the actual zookeeper binding port zkPort = zookeeper.actualPort - zkAddress = s"$zkHost:$zkPort" - logInfo("==================== 0 ====================") + zkReady = true + logInfo("==================== Zookeeper Started ====================") - zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, - ZKStringSerializer) - logInfo("==================== 1 ====================") + zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) + logInfo("==================== Zookeeper Client Created ====================") // Kafka broker startup var bindSuccess: Boolean = false @@ -80,9 +91,8 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin val brokerProps = getBrokerConfig() brokerConf = new KafkaConfig(brokerProps) server = new KafkaServer(brokerConf) - logInfo("==================== 2 ====================") server.startup() - logInfo("==================== 3 ====================") + logInfo("==================== Kafka Broker Started ====================") bindSuccess = true } catch { case e: KafkaException => @@ -94,10 +104,13 @@ abstract 
class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin } Thread.sleep(2000) - logInfo("==================== 4 ====================") + logInfo("==================== Kafka + Zookeeper Ready ====================") + brokerReady = true } def tearDownKafka() { + brokerReady = false + zkReady = false if (producer != null) { producer.close() producer = null @@ -121,26 +134,23 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin } } - private def createTestMessage(topic: String, sent: Map[String, Int]) - : Seq[KeyedMessage[String, String]] = { - val messages = for ((s, freq) <- sent; i <- 0 until freq) yield { - new KeyedMessage[String, String](topic, s) - } - messages.toSeq - } - def createTopic(topic: String) { - CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") - logInfo("==================== 5 ====================") + AdminUtils.createTopic(zkClient, topic, 1, 1) // wait until metadata is propagated waitUntilMetadataIsPropagated(topic, 0) + logInfo(s"==================== Topic $topic Created ====================") } - def produceAndSendMessage(topic: String, sent: Map[String, Int]) { + def sendMessages(topic: String, messageToFreq: Map[String, Int]) { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + def sendMessages(topic: String, messages: Array[String]) { producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) - producer.send(createTestMessage(topic, sent): _*) + producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) producer.close() - logInfo("==================== 6 ====================") + logInfo(s"==================== Sent Messages: ${messages.mkString(", ")} ====================") } private def getBrokerConfig(): Properties = { @@ -164,9 +174,9 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin } private def waitUntilMetadataIsPropagated(topic: 
String, partition: Int) { - eventually(timeout(1000 milliseconds), interval(100 milliseconds)) { + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert( - server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)), + server.apis.metadataCache.containsTopicAndPartition(topic, partition), s"Partition [$topic, $partition] metadata not propagated after timeout" ) } @@ -218,7 +228,7 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) createTopic(topic) - produceAndSendMessage(topic, sent) + sendMessages(topic, sent) val kafkaParams = Map("zookeeper.connect" -> zkAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}", diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 64ccc92c81fa9..fc53c23abda85 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -79,7 +79,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter test("Reliable Kafka input stream with single topic") { var topic = "test-topic" createTopic(topic) - produceAndSendMessage(topic, data) + sendMessages(topic, data) // Verify whether the offset of this group/topic/partition is 0 before starting. assert(getCommitOffset(groupId, topic, 0) === None) @@ -111,7 +111,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) topics.foreach { case (t, _) => createTopic(t) - produceAndSendMessage(t, data) + sendMessages(t, data) } // Before started, verify all the group/topic/partition offsets are 0. 
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 77661f71ada21..3c0ef94cb0fab 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -17,23 +17,23 @@ package org.apache.spark.streaming.mqtt +import java.io.IOException +import java.util.concurrent.Executors +import java.util.Properties + +import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import java.util.Properties -import java.util.concurrent.Executors -import java.io.IOException - +import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttCallback import org.eclipse.paho.client.mqttv3.MqttClient import org.eclipse.paho.client.mqttv3.MqttClientPersistence -import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence -import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.MqttMessage import org.eclipse.paho.client.mqttv3.MqttTopic +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel @@ -55,14 +55,14 @@ class MQTTInputDStream( brokerUrl: String, topic: String, storageLevel: StorageLevel - ) extends ReceiverInputDStream[String](ssc_) with Logging { - + ) extends ReceiverInputDStream[String](ssc_) { + def getReceiver(): Receiver[String] = { new MQTTReceiver(brokerUrl, topic, storageLevel) } } -private[streaming] +private[streaming] class MQTTReceiver( brokerUrl: String, topic: String, @@ -72,38 +72,40 @@ class MQTTReceiver( def onStop() { } - + def onStart() { - // Set up persistence for messages 
+ // Set up persistence for messages val persistence = new MemoryPersistence() // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) - // Connect to MqttBroker - client.connect() - - // Subscribe to Mqtt topic - client.subscribe(topic) - // Callback automatically triggers as and when new message arrives on specified topic - val callback: MqttCallback = new MqttCallback() { + val callback = new MqttCallback() { // Handles Mqtt message - override def messageArrived(arg0: String, arg1: MqttMessage) { - store(new String(arg1.getPayload(),"utf-8")) + override def messageArrived(topic: String, message: MqttMessage) { + store(new String(message.getPayload(),"utf-8")) } - override def deliveryComplete(arg0: IMqttDeliveryToken) { + override def deliveryComplete(token: IMqttDeliveryToken) { } - override def connectionLost(arg0: Throwable) { - restart("Connection lost ", arg0) + override def connectionLost(cause: Throwable) { + restart("Connection lost ", cause) } } - // Set up callback for MqttClient + // Set up callback for MqttClient. 
This needs to happen before + // connecting or subscribing, otherwise messages may be lost client.setCallback(callback) + + // Connect to MqttBroker + client.connect() + + // Subscribe to Mqtt topic + client.subscribe(topic) + } } diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index c5ffe51f9986c..1142d0f56ba34 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.mqtt +import scala.reflect.ClassTag + import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream} -import scala.reflect.ClassTag import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} object MQTTUtils { diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60c..cfedb5a042a35 100644 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 
fe53a29cba0c9..0f3298af6234a 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.streaming.mqtt import java.net.{URI, ServerSocket} +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import scala.language.postfixOps @@ -32,14 +34,16 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.scheduler.StreamingListener +import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.SparkConf import org.apache.spark.util.Utils class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) - private val master: String = "local[2]" - private val framework: String = this.getClass.getSimpleName + private val master = "local[2]" + private val framework = this.getClass.getSimpleName private val freePort = findFreePort() private val brokerUri = "//localhost:" + freePort private val topic = "def" @@ -65,9 +69,9 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream: ReceiverInputDStream[String] = + val receiveStream = MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) - var receiveMessage: List[String] = List() + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { receiveMessage = receiveMessage ::: List(rdd.first) @@ -75,6 +79,11 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { } } ssc.start() + + // 
wait for the receiver to start before publishing data, or we risk failing + // the test nondeterministically. See SPARK-4631 + waitForReceiverToStart() + publishData(sendMessage) eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert(sendMessage.equals(receiveMessage(0))) @@ -84,6 +93,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { private def setupMQTT() { broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) connector = new TransportConnector() connector.setName("mqtt") connector.setUri(new URI("mqtt:" + brokerUri)) @@ -113,16 +123,22 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { def publishData(data: String): Unit = { var client: MqttClient = null try { - val persistence: MqttClientPersistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) client.connect() if (client.isConnected) { - val msgTopic: MqttTopic = client.getTopic(topic) - val message: MqttMessage = new MqttMessage(data.getBytes("utf-8")) + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes("utf-8")) message.setQos(1) message.setRetained(true) - for (i <- 0 to 100) { - msgTopic.publish(message) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + Thread.sleep(50) // wait for Spark streaming to consume something from the message queue + } } } } finally { @@ -131,4 +147,18 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { client = null } } + + /** + * Block until at least one receiver has started or timeout occurs. 
+ */ + private def waitForReceiverToStart() = { + val latch = new CountDownLatch(1) + ssc.addStreamingListener(new StreamingListener { + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { + latch.countDown() + } + }) + + assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") + } } diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60c..cfedb5a042a35 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60c..cfedb5a042a35 100644 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 
c815eda52bda7..216661b8bc73a 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -67,11 +67,6 @@ scalacheck_${scala.binary.version} test - - org.easymock - easymockclassextension - test - com.novocode junit-interface diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 0b80b611cdce7..588e86a1887ec 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -18,9 +18,7 @@ package org.apache.spark.streaming.kinesis import org.apache.spark.Logging import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.util.Clock -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.streaming.util.SystemClock +import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. @@ -35,7 +33,7 @@ private[kinesis] class KinesisCheckpointState( /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() - checkpointClock.setTime(currentClock.currentTime() + checkpointInterval.milliseconds) + checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** * Check if it's time to checkpoint based on the current time and the derived time @@ -44,13 +42,13 @@ private[kinesis] class KinesisCheckpointState( * @return true if it's time to checkpoint */ def shouldCheckpoint(): Boolean = { - new SystemClock().currentTime() > checkpointClock.currentTime() + new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis() } /** * Advance the checkpoint clock by the checkpoint interval. 
*/ def advanceCheckpoint() = { - checkpointClock.addToTime(checkpointInterval.milliseconds) + checkpointClock.advance(checkpointInterval.milliseconds) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 8ecc2d90160b1..af8cd875b4541 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -104,7 +104,7 @@ private[kinesis] class KinesisRecordProcessor( logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + s" records for shardId $shardId") logDebug(s"Checkpoint: Next checkpoint is at " + - s" ${checkpointState.checkpointClock.currentTime()} for shardId $shardId") + s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") } } catch { case e: Throwable => { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 41dbd64c2b1fa..255fe65819608 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -20,17 +20,17 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions.seqAsJavaList -import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Milliseconds import org.apache.spark.streaming.Seconds import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.TestSuiteBase -import org.apache.spark.streaming.util.Clock -import org.apache.spark.streaming.util.ManualClock +import 
org.apache.spark.util.{ManualClock, Clock} + +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.scalatest.mock.EasyMockSugar +import org.scalatest.mock.MockitoSugar import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException @@ -42,10 +42,10 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record /** - * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor */ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter - with EasyMockSugar { + with MockitoSugar { val app = "TestKinesisReceiver" val stream = "mySparkStream" @@ -73,6 +73,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft currentClockMock = mock[Clock] } + override def afterFunction(): Unit = { + super.afterFunction() + // Since this suite was originally written using EasyMock, add this to preserve the old + // mocking semantics (see SPARK-5735 for more details) + verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, + checkpointStateMock, currentClockMock) + } + test("kinesis utils api") { val ssc = new StreamingContext(master, framework, batchDuration) // Tests the API, does not actually test data receiving @@ -83,193 +91,175 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft } test("process records including store and checkpoint") { - val expectedCheckpointIntervalMillis = 10 - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).once() - receiverMock.store(record2.getData().array()).once() - checkpointStateMock.shouldCheckpoint().andReturn(true).once() - 
checkpointerMock.checkpoint().once() - checkpointStateMock.advanceCheckpoint().once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).store(record1.getData().array()) + verify(receiverMock, times(1)).store(record2.getData().array()) + verify(checkpointStateMock, times(1)).shouldCheckpoint() + verify(checkpointerMock, times(1)).checkpoint() + verify(checkpointStateMock, times(1)).advanceCheckpoint() } test("shouldn't store and checkpoint when receiver is stopped") { - expecting { - receiverMock.isStopped().andReturn(true).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() } test("shouldn't checkpoint when exception occurs during store") { - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - 
recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when(receiverMock.store(record1.getData().array())).thenThrow(new RuntimeException()) + + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) } + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).store(record1.getData().array()) } test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { + when(currentClockMock.getTimeMillis()).thenReturn(0) + val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - } + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) + + verify(currentClockMock, times(1)).getTimeMillis() } test("should checkpoint if we have exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) + assert(checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).getTimeMillis() } test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - 
whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) + assert(!checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).getTimeMillis() } test("should add to time when advancing checkpoint") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointIntervalMillis = 10 + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) + checkpointState.advanceCheckpoint() + assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis)) + + verify(currentClockMock, times(1)).getTimeMillis() } test("shutdown should checkpoint if the reason is TERMINATE") { - expecting { - checkpointerMock.checkpoint().once() - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - val reason = ShutdownReason.TERMINATE - recordProcessor.shutdown(checkpointerMock, reason) - } + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val reason = ShutdownReason.TERMINATE + recordProcessor.shutdown(checkpointerMock, 
reason) + + verify(checkpointerMock, times(1)).checkpoint() } test("shutdown should not checkpoint if the reason is something other than TERMINATE") { - expecting { - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) - recordProcessor.shutdown(checkpointerMock, null) - } + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + + verify(checkpointerMock, never()).checkpoint() } test("retry success on first attempt") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()).thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(1)).isStopped() } test("retry success on second attempt after a Kinesis throttling exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new ThrottlingException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new ThrottlingException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry 
success on second attempt after a Kinesis dependency exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new KinesisClientLibDependencyException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry failed after a shutdown exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[ShutdownException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new ShutdownException("error message")) + + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after an invalid state exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[InvalidStateException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new InvalidStateException("error message")) + + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after unexpected exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new 
RuntimeException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new RuntimeException("error message")) + + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after exhausing all retries") { val expectedErrorMessage = "final try error message" - expecting { - checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message")) - .andThrow(new ThrottlingException(expectedErrorMessage)).once() - } - whenExecuting(checkpointerMock) { - val exception = intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } - exception.getMessage().shouldBe(expectedErrorMessage) + when(checkpointerMock.checkpoint()) + .thenThrow(new ThrottlingException("error message")) + .thenThrow(new ThrottlingException(expectedErrorMessage)) + + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + exception.getMessage().shouldBe(expectedErrorMessage) + + verify(checkpointerMock, times(2)).checkpoint() } } diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index d1427f6a0c6e9..f2f0aa78b0a4b 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -42,7 +42,7 @@ - com.codahale.metrics + io.dropwizard.metrics metrics-ganglia diff --git a/graphx/pom.xml b/graphx/pom.xml index 72374aae6da9b..8fac24b6ed86d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -40,6 +40,10 @@ spark-core_${scala.binary.version} ${project.version} + + com.google.guava + guava + org.jblas jblas diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala 
index 84b72b390ca35..8494d06b1cdb7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -55,7 +55,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * @return an RDD containing the edges in this graph * * @see [[Edge]] for the edge type. - * @see [[triplets]] to get an RDD which contains all the edges + * @see [[Graph#triplets]] to get an RDD which contains all the edges * along with their vertex data. * */ @@ -104,6 +104,18 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab */ def checkpoint(): Unit + /** + * Return whether this Graph has been checkpointed or not. + * This returns true iff both the vertices RDD and edges RDD have been checkpointed. + */ + def isCheckpointed: Boolean + + /** + * Gets the name of the files to which this Graph was checkpointed. + * (The vertices RDD and edges RDD are checkpointed separately.) + */ + def getCheckpointFiles: Seq[String] + /** * Uncaches both vertices and edges of this graph. This is useful in iterative algorithms that * build a new graph in each iteration. 
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index 4933aecba1286..21187be7678a6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -77,7 +77,7 @@ object GraphLoader extends Logging { if (!line.isEmpty && line(0) != '#') { val lineArray = line.split("\\s+") if (lineArray.length < 2) { - logWarning("Invalid line: " + line) + throw new IllegalArgumentException("Invalid line: " + line) } val srcId = lineArray(0).toLong val dstId = lineArray(1).toLong diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 897c7ee12a436..56cb41661e300 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( * partitioner that allows co-partitioning with `partitionsRDD`. 
*/ override val partitioner = - partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) + partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size))) override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() @@ -70,10 +70,20 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( this } + override def getStorageLevel = partitionsRDD.getStorageLevel + override def checkpoint() = { partitionsRDD.checkpoint() } - + + override def isCheckpointed: Boolean = { + firstParent[(PartitionID, EdgePartition[ED, VD])].isCheckpointed + } + + override def getCheckpointFile: Option[String] = { + partitionsRDD.getCheckpointFile + } + /** The number of edges in the RDD. */ override def count(): Long = { partitionsRDD.map(_._2.size.toLong).reduce(_ + _) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 3f4a900d5b601..90a74d23a26cc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -70,6 +70,17 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( replicatedVertexView.edges.checkpoint() } + override def isCheckpointed: Boolean = { + vertices.isCheckpointed && replicatedVertexView.edges.isCheckpointed + } + + override def getCheckpointFiles: Seq[String] = { + Seq(vertices.getCheckpointFile, replicatedVertexView.edges.getCheckpointFile).flatMap { + case Some(path) => Seq(path) + case None => Seq() + } + } + override def unpersist(blocking: Boolean = true): Graph[VD, ED] = { unpersistVertices(blocking) replicatedVertexView.edges.unpersist(blocking) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 9732c5b00c6d9..904be213147dc 100644 --- 
a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -71,10 +71,20 @@ class VertexRDDImpl[VD] private[graphx] ( this } + override def getStorageLevel = partitionsRDD.getStorageLevel + override def checkpoint() = { partitionsRDD.checkpoint() } - + + override def isCheckpointed: Boolean = { + firstParent[ShippableVertexPartition[VD]].isCheckpointed + } + + override def getCheckpointFile: Option[String] = { + partitionsRDD.getCheckpointFile + } + /** The number of vertices in the RDD. */ override def count(): Long = { partitionsRDD.map(_.size).reduce(_ + _) @@ -94,8 +104,14 @@ class VertexRDDImpl[VD] private[graphx] ( this.mapVertexPartitions(_.map(f)) override def diff(other: VertexRDD[VD]): VertexRDD[VD] = { + val otherPartition = other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + other.partitionsRDD + case _ => + VertexRDD(other.partitionBy(this.partitioner.get)).partitionsRDD + } val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true + otherPartition, preservesPartitioning = true ) { (thisIter, otherIter) => val thisPart = thisIter.next() val otherPart = otherIter.next() @@ -123,7 +139,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. // If the other set is a VertexRDD then we use the much more efficient leftZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => leftZipJoin(other)(f) case _ => this.withPartitionsRDD[VD3]( @@ -152,7 +168,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. 
// If the other set is a VertexRDD then we use the much more efficient innerZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => innerZipJoin(other)(f) case _ => this.withPartitionsRDD( diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index f58587e10a820..3e4157a63fd1c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -37,6 +37,17 @@ object SVDPlusPlus { var gamma7: Double) extends Serializable + /** + * This method is now replaced by the updated version of `run()` and returns exactly + * the same result. + */ + @deprecated("Call run()", "1.4.0") + def runSVDPlusPlus(edges: RDD[Edge[Double]], conf: Conf) + : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = + { + run(edges, conf) + } + /** * Implement SVD++ based on "Factorization Meets the Neighborhood: * a Multifaceted Collaborative Filtering Model", @@ -52,7 +63,7 @@ object SVDPlusPlus { * @return a graph with vertex attributes containing the trained model */ def run(edges: RDD[Edge[Double]], conf: Conf) - : (Graph[(DoubleMatrix, DoubleMatrix, Double, Double), Double], Double) = + : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = { // Generate default vertex attribute def defaultF(rank: Int): (DoubleMatrix, DoubleMatrix, Double, Double) = { @@ -72,17 +83,22 @@ object SVDPlusPlus { // construct graph var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() + materialize(g) + edges.unpersist() // Calculate initial bias and norm val t0 = g.aggregateMessages[(Long, Double)]( ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) - g = g.outerJoinVertices(t0) { + val gJoinT0 = g.outerJoinVertices(t0) { (vid: VertexId, vd: 
(DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(Long, Double)]) => (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) - } + }.cache() + materialize(gJoinT0) + g.unpersist() + g = gJoinT0 def sendMsgTrainF(conf: Conf, u: Double) (ctx: EdgeContext[ @@ -114,12 +130,15 @@ object SVDPlusPlus { val t1 = g.aggregateMessages[DoubleMatrix]( ctx => ctx.sendToSrc(ctx.dstAttr._2), (g1, g2) => g1.addColumnVector(g2)) - g = g.outerJoinVertices(t1) { + val gJoinT1 = g.outerJoinVertices(t1) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[DoubleMatrix]) => if (msg.isDefined) (vd._1, vd._1 .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd - } + }.cache() + materialize(gJoinT1) + g.unpersist() + g = gJoinT1 // Phase 2, update p for user nodes and q, y for item nodes g.cache() @@ -127,13 +146,16 @@ object SVDPlusPlus { sendMsgTrainF(conf, u), (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) - g = g.outerJoinVertices(t2) { + val gJoinT2 = g.outerJoinVertices(t2) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) => (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2), vd._3 + msg.get._3, vd._4) - } + }.cache() + materialize(gJoinT2) + g.unpersist() + g = gJoinT2 } // calculate error on training set @@ -147,13 +169,28 @@ object SVDPlusPlus { val err = (ctx.attr - pred) * (ctx.attr - pred) ctx.sendToDst(err) } + g.cache() val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) - g = g.outerJoinVertices(t3) { + val gJoinT3 = g.outerJoinVertices(t3) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd - } + }.cache() + materialize(gJoinT3) + g.unpersist() + g = gJoinT3 - (g, u) + // Convert DoubleMatrix to 
Array[Double]: + val newVertices = g.vertices.mapValues(v => (v._1.toArray, v._2.toArray, v._3, v._4)) + (Graph(newVertices, g.edges), u) } + + /** + * Forces materialization of a Graph by count()ing its RDDs. + */ + private def materialize(g: Graph[_,_]): Unit = { + g.vertices.count() + g.edges.count() + } + } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 590f0474957dd..179f2843818e0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -61,8 +61,8 @@ object ShortestPaths { } def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = { - val newAttr = incrementMap(edge.srcAttr) - if (edge.dstAttr != addMaps(newAttr, edge.dstAttr)) Iterator((edge.dstId, newAttr)) + val newAttr = incrementMap(edge.dstAttr) + if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr)) else Iterator.empty } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 8a13c74221546..2d6a825b61726 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -133,6 +133,12 @@ object GraphGenerators { // This ensures that the 4 quadrants are the same size at all recursion levels val numVertices = math.round( math.pow(2.0, math.ceil(math.log(requestedNumVertices) / math.log(2.0)))).toInt + val numEdgesUpperBound = + math.pow(2.0, 2 * ((math.log(numVertices) / math.log(2.0)) - 1)).toInt + if (numEdgesUpperBound < numEdges) { + throw new IllegalArgumentException( + s"numEdges must be <= $numEdgesUpperBound but was $numEdges") + } var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { 
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala new file mode 100644 index 0000000000000..eb1dbe52c2fda --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.graphx + +import org.scalatest.FunSuite + +import org.apache.spark.storage.StorageLevel + +class EdgeRDDSuite extends FunSuite with LocalSparkContext { + + test("cache, getStorageLevel") { + // test to see if getStorageLevel returns correct value after caching + withSpark { sc => + val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3))) + val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) + assert(edges.getStorageLevel == StorageLevel.NONE) + edges.cache() + assert(edges.getStorageLevel == StorageLevel.MEMORY_ONLY) + } + } + +} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 9da0064104fb6..b61d9f0fbe5e4 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel class GraphSuite extends FunSuite with LocalSparkContext { @@ -375,6 +376,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1)} val rdd = sc.parallelize(ring) val graph = Graph.fromEdges(rdd, 1.0F) + assert(!graph.isCheckpointed) + assert(graph.getCheckpointFiles.size === 0) graph.checkpoint() graph.edges.map(_.attr).count() graph.vertices.map(_._2).count() @@ -383,6 +386,42 @@ class GraphSuite extends FunSuite with LocalSparkContext { val verticesDependencies = graph.vertices.partitionsRDD.dependencies assert(edgesDependencies.forall(_.rdd.isInstanceOf[CheckpointRDD[_]])) assert(verticesDependencies.forall(_.rdd.isInstanceOf[CheckpointRDD[_]])) + assert(graph.isCheckpointed) + assert(graph.getCheckpointFiles.size === 2) + } + } + + 
test("cache, getStorageLevel") { + // test to see if getStorageLevel returns correct value + withSpark { sc => + val verts = sc.parallelize(List((1: VertexId, "a"), (2: VertexId, "b")), 1) + val edges = sc.parallelize(List(Edge(1, 2, 0), Edge(2, 1, 0)), 2) + val graph = Graph(verts, edges, "", StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY) + // Note: Before caching, graph.vertices is cached, but graph.edges is not (but graph.edges' + // parent RDD is cached). + graph.cache() + assert(graph.vertices.getStorageLevel == StorageLevel.MEMORY_ONLY) + assert(graph.edges.getStorageLevel == StorageLevel.MEMORY_ONLY) + } + } + + test("non-default number of edge partitions") { + val n = 10 + val defaultParallelism = 3 + val numEdgePartitions = 4 + assert(defaultParallelism != numEdgePartitions) + val conf = new org.apache.spark.SparkConf() + .set("spark.default.parallelism", defaultParallelism.toString) + val sc = new SparkContext("local", "test", conf) + try { + val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), + numEdgePartitions) + val graph = Graph.fromEdgeTuples(edges, 1) + val neighborAttrSums = graph.mapReduceTriplets[Int]( + et => Iterator((et.dstId, et.srcAttr)), _ + _) + assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) + } finally { + sc.stop() } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index 42d3f21dbae98..131959cea3ef7 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.graphx -import org.apache.spark.SparkContext -import org.apache.spark.graphx.Graph._ -import org.apache.spark.graphx.impl.EdgePartition -import org.apache.spark.rdd._ import org.scalatest.FunSuite +import org.apache.spark.SparkContext +import org.apache.spark.storage.StorageLevel + class VertexRDDSuite 
extends FunSuite with LocalSparkContext { def vertices(sc: SparkContext, n: Int) = { @@ -110,4 +109,16 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } + test("cache, getStorageLevel") { + // test to see if getStorageLevel returns correct value after caching + withSpark { sc => + val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3))) + val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) + val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b) + assert(rdd.getStorageLevel == StorageLevel.NONE) + rdd.cache() + assert(rdd.getStorageLevel == StorageLevel.MEMORY_ONLY) + } + } + } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index e01df56e94de9..9987a4b1a3c25 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -32,7 +32,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations - var (graph, u) = SVDPlusPlus.run(edges, conf) + var (graph, u) = SVDPlusPlus.runSVDPlusPlus(edges, conf) graph.cache() val err = graph.vertices.collect().map{ case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index 265827b3341c2..f2c38e79c452c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -40,7 +40,7 @@ class ShortestPathsSuite extends FunSuite with LocalSparkContext { val graph = Graph.fromEdgeTuples(edges, 1) val landmarks = 
Seq(1, 4).map(_.toLong) val results = ShortestPaths.run(graph, landmarks).vertices.collect.map { - case (v, spMap) => (v, spMap.mapValues(_.get)) + case (v, spMap) => (v, spMap.mapValues(i => i)) } assert(results.toSet === shortestPaths) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 3abefbe52fa8a..8d9c8ddccbb3c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -110,4 +110,14 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { } } + test("SPARK-5064 GraphGenerators.rmatGraph numEdges upper bound") { + withSpark { sc => + val g1 = GraphGenerators.rmatGraph(sc, 4, 4) + assert(g1.edges.count() === 4) + intercept[IllegalArgumentException] { + val g2 = GraphGenerators.rmatGraph(sc, 4, 8) + } + } + } + } diff --git a/make-distribution.sh b/make-distribution.sh index 4e2f400be3053..dd990d4b96e46 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,6 +32,10 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false +TACHYON_VERSION="0.5.0" +TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" +TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" + MAKE_TGZ=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -93,10 +97,10 @@ done if [ -z "$JAVA_HOME" ]; then # Fall back on JAVA_HOME from rpm, if found - if which rpm &>/dev/null; then - RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null) + if [ $(command -v rpm) ]; then + RPM_JAVA_HOME="$(rpm -E %java_home 2>/dev/null)" if [ "$RPM_JAVA_HOME" != "%java_home" ]; then - JAVA_HOME=$RPM_JAVA_HOME + JAVA_HOME="$RPM_JAVA_HOME" echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" fi fi @@ -107,25 +111,26 @@ if [ -z "$JAVA_HOME" ]; then exit -1 fi -if which 
git &>/dev/null; then +if [ $(command -v git) ]; then GITREV=$(git rev-parse --short HEAD 2>/dev/null || :) - if [ ! -z $GITREV ]; then + if [ ! -z "$GITREV" ]; then GITREVSTRING=" (git revision $GITREV)" fi unset GITREV fi -if ! which $MVN &>/dev/null; then + +if [ ! $(command -v "$MVN") ] ; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; fi -VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) -SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ +VERSION=$("$MVN" help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) +SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) -SPARK_HIVE=$($MVN help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ +SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ @@ -142,7 +147,7 @@ if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then echo "Output from 'java -version' was:" echo "$JAVA_VERSION" read -p "Would you like to continue anyways? [y,n]: " -r - if [[ ! $REPLY =~ ^[Yy]$ ]]; then + if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then echo "Okay, exiting." exit 1 fi @@ -171,13 +176,16 @@ cd "$SPARK_HOME" export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" -BUILD_COMMAND="$MVN clean package -DskipTests $@" +# Store the command as an array because $MVN variable might have spaces in it. +# Normal quoting tricks don't work. +# See: http://mywiki.wooledge.org/BashFAQ/050 +BUILD_COMMAND=("$MVN" clean package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." 
-echo -e "\$ $BUILD_COMMAND\n" +echo -e "\$ ${BUILD_COMMAND[@]}\n" -${BUILD_COMMAND} +"${BUILD_COMMAND[@]}" # Make directories rm -rf "$DISTDIR" @@ -222,16 +230,22 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then - TACHYON_VERSION="0.5.0" - TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" - TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` - pushd $TMPD > /dev/null + pushd "$TMPD" > /dev/null echo "Fetching tachyon tgz" - wget "$TACHYON_URL" - tar xf "tachyon-${TACHYON_VERSION}-bin.tar.gz" + TACHYON_DL="${TACHYON_TGZ}.part" + if [ $(command -v curl) ]; then + curl --silent -k -L "${TACHYON_URL}" > "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" + elif [ $(command -v wget) ]; then + wget --quiet "${TACHYON_URL}" -O "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" + else + printf "You do not have curl or wget installed. please install Tachyon manually.\n" + exit -1 + fi + + tar xzf "${TACHYON_TGZ}" cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" @@ -245,7 +259,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then fi popd > /dev/null - rm -rf $TMPD + rm -rf "$TMPD" fi if [ "$MAKE_TGZ" == "true" ]; then diff --git a/mllib/pom.xml b/mllib/pom.xml index a0bda89ccaa71..a8cee3d51a780 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -50,6 +50,11 @@ spark-sql_${scala.binary.version} ${project.version} + + org.apache.spark + spark-graphx_${scala.binary.version} + ${project.version} + org.jblas jblas @@ -125,6 +130,9 @@ ../python pyspark/mllib/*.py + pyspark/mllib/stat/*.py + pyspark/ml/*.py + pyspark/ml/param/*.py diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index fdbee743e8177..eff7ef925dfbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -18,12 +18,10 @@ package org.apache.spark.ml import scala.annotation.varargs -import scala.collection.JavaConverters._ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -36,11 +34,12 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with optional parameters. * * @param dataset input dataset - * @param paramPairs optional list of param pairs (overwrite embedded params) + * @param paramPairs Optional list of param pairs. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ @varargs - def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { val map = new ParamMap().put(paramPairs: _*) fit(dataset, map) } @@ -49,10 +48,11 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with provided parameter map. * * @param dataset input dataset - * @param paramMap parameter map + * @param paramMap Parameter map. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ - def fit(dataset: SchemaRDD, paramMap: ParamMap): M + def fit(dataset: DataFrame, paramMap: ParamMap): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -60,46 +60,11 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Subclasses could overwrite this to optimize multi-model training. 
* * @param dataset input dataset - * @param paramMaps an array of parameter maps + * @param paramMaps An array of parameter maps. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted models, matching the input parameter maps */ - def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } - - // Java-friendly versions of fit. - - /** - * Fits a single model to the input data with optional parameters. - * - * @param dataset input dataset - * @param paramPairs optional list of param pairs (overwrite embedded params) - * @return fitted model - */ - @varargs - def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = { - fit(dataset.schemaRDD, paramPairs: _*) - } - - /** - * Fits a single model to the input data with provided parameter map. - * - * @param dataset input dataset - * @param paramMap parameter map - * @return fitted model - */ - def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = { - fit(dataset.schemaRDD, paramMap) - } - - /** - * Fits multiple models to the input data with multiple sets of parameters. 
- * - * @param dataset input dataset - * @param paramMaps an array of parameter maps - * @return fitted models, matching the input parameter maps - */ - def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = { - fit(dataset.schemaRDD, paramMaps).asJava - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index db563dd550e56..d2ca2e6871e6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double + def evaluate(dataset: DataFrame, paramMap: ParamMap): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index ad6fed178fae9..5bbcd2e080e07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -20,9 +20,9 @@ package org.apache.spark.ml import scala.collection.mutable.ListBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType abstract class PipelineStage extends Serializable with Logging { /** + * :: DeveloperAPI :: + * * 
Derives the output schema from the input schema and parameters. + * The schema describes the columns and types of the data. + * + * @param schema Input schema to this stage + * @param paramMap Parameters passed to this stage + * @return Output schema from this stage */ - private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType + @DeveloperApi + def transformSchema(schema: StructType, paramMap: ParamMap): StructType /** * Derives the output schema from the input schema and parameters, optionally with logging. @@ -58,11 +66,11 @@ abstract class PipelineStage extends Serializable with Logging { /** * :: AlphaComponent :: * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each - * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the - * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will + * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the + * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will * be called on the input dataset to fit a model. Then the model, which is a transformer, will be * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]], - * its [[Transformer.transform]] method will be called to produce the dataset for the next stage. + * its [[Transformer#transform]] method will be called to produce the dataset for the next stage. * The fitted model from a [[Pipeline]] is an [[PipelineModel]], which consists of fitted models and * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. @@ -77,9 +85,9 @@ class Pipeline extends Estimator[PipelineModel] { /** * Fits the pipeline to the input dataset with additional parameters. 
If a stage is an - * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model. + * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model. * Then the model, which is a transformer, will be used to transform the dataset as the input to - * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be + * the next stage. If a stage is a [[Transformer]], its [[Transformer#transform]] method will be * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is an * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the * pipeline stages. If there are no stages, the output model acts as an identity transformer. @@ -88,7 +96,7 @@ class Pipeline extends Estimator[PipelineModel] { * @param paramMap parameter map * @return fitted pipeline */ - override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap val theStages = map(stages) @@ -114,7 +122,9 @@ class Pipeline extends Estimator[PipelineModel] { throw new IllegalArgumentException( s"Do not support stage $stage of type ${stage.getClass}") } - curDataset = transformer.transform(curDataset, paramMap) + if (index < indexOfLastEstimator) { + curDataset = transformer.transform(curDataset, paramMap) + } transformers += transformer } else { transformers += stage.asInstanceOf[Transformer] @@ -124,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] { new PipelineModel(this, map, transformers.toArray) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = this.paramMap ++ paramMap val theStages = map(stages) 
require(theStages.toSet.size == theStages.size, @@ -162,14 +172,14 @@ class PipelineModel private[ml] ( } } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap val map = (fittingParamMap ++ this.paramMap) ++ paramMap transformSchema(dataset.schema, map, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap val map = (fittingParamMap ++ this.paramMap) ++ paramMap stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 1331b9124045c..9a5848684b179 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,10 +22,8 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.api.java.JavaSchemaRDD -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** @@ -42,7 +40,7 @@ abstract class Transformer extends PipelineStage with Params { * @return transformed dataset */ @varargs - def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame 
= { val map = new ParamMap() paramPairs.foreach(map.put(_)) transform(dataset, map) @@ -54,30 +52,7 @@ abstract class Transformer extends PipelineStage with Params { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD - - // Java-friendly versions of transform. - - /** - * Transforms the dataset with optional parameters. - * @param dataset input datset - * @param paramPairs optional list of param pairs, overwrite embedded params - * @return transformed dataset - */ - @varargs - def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = { - transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD - } - - /** - * Transforms the dataset with provided parameter map as additional parameters. - * @param dataset input dataset - * @param paramMap additional parameters, overwrite embedded params - * @return transformed dataset - */ - def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = { - transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD - } + def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame } /** @@ -87,7 +62,10 @@ abstract class Transformer extends PipelineStage with Params { private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { + /** @group setParam */ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + + /** @group setParam */ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] /** @@ -119,11 +97,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O StructType(outputFields) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import 
dataset.sqlContext._ val map = this.paramMap ++ paramMap - val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr)) - dataset.select(Star(None), udf as map(outputCol)) + dataset.withColumn(map(outputCol), + callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala new file mode 100644 index 0000000000000..c5fc89f935432 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + + +/** + * :: DeveloperApi :: + * Params for classification. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait ClassifierParams extends PredictorParams + with HasRawPredictionCol { + + override protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) + val map = this.paramMap ++ paramMap + addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Single-label binary or multiclass classification. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Classifier[ + FeaturesType, + E <: Classifier[FeaturesType, E, M], + M <: ClassificationModel[FeaturesType, M]] + extends Predictor[FeaturesType, E, M] + with ClassifierParams { + + /** @group setParam */ + def setRawPredictionCol(value: String): E = + set(rawPredictionCol, value).asInstanceOf[E] + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * :: AlphaComponent :: + * Model produced by a [[Classifier]]. 
+ * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] +abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with ClassifierParams { + + /** @group setParam */ + def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] + + /** Number of classes (values which the label can take). */ + def numClasses: Int + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. + + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + val (numColsOutput, outputData) = + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + if (numColsOutput == 0) { + logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + + /** + * :: DeveloperApi :: + * + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
+ * + * This default implementation for classification predicts the index of the maximum value + * from [[predictRaw()]]. + */ + @DeveloperApi + override protected def predict(features: FeaturesType): Double = { + predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2 + } + + /** + * :: DeveloperApi :: + * + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + */ + @DeveloperApi + protected def predictRaw(features: FeaturesType): Vector + +} + +private[ml] object ClassificationModel { + + /** + * Added prediction column(s). This is separated from [[ClassificationModel.transform()]] + * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]]. + * @param dataset Input dataset + * @param map Parameter map. This will NOT be merged with the embedded paramMap; the merge + * should already be done. + * @return (number of columns added, transformed dataset) + */ + def transformColumnsImpl[FeaturesType]( + dataset: DataFrame, + model: ClassificationModel[FeaturesType, _], + map: ParamMap): (Int, DataFrame) = { + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. 
+ var tmpData = dataset + var numColsOutput = 0 + if (map(model.rawPredictionCol) != "") { + // output raw prediction + val features2raw: FeaturesType => Vector = model.predictRaw + tmpData = tmpData.withColumn(map(model.rawPredictionCol), + callUDF(features2raw, new VectorUDT, col(map(model.featuresCol)))) + numColsOutput += 1 + if (map(model.predictionCol) != "") { + val raw2pred: Vector => Double = (rawPred) => { + rawPred.toArray.zipWithIndex.maxBy(_._1)._2 + } + tmpData = tmpData.withColumn(map(model.predictionCol), + callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol)))) + numColsOutput += 1 + } + } else if (map(model.predictionCol) != "") { + // output prediction + val features2pred: FeaturesType => Double = model.predict + tmpData = tmpData.withColumn(map(model.predictionCol), + callUDF(features2pred, DoubleType, col(map(model.featuresCol)))) + numColsOutput += 1 + } + (numColsOutput, tmpData) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8c570812f8316..21f61d80dd95a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -18,132 +18,185 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.dsl._ -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} +import 
org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel + /** - * :: AlphaComponent :: * Params for logistic regression. */ -@AlphaComponent -private[classification] trait LogisticRegressionParams extends Params - with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol - with HasScoreCol with HasPredictionCol { +private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams + with HasRegParam with HasMaxIter with HasThreshold - /** - * Validates and transforms the input schema with the provided param map. - * @param schema input schema - * @param paramMap additional parameters - * @param fitting whether this is in fitting - * @return output schema - */ - protected def validateAndTransformSchema( - schema: StructType, - paramMap: ParamMap, - fitting: Boolean): StructType = { - val map = this.paramMap ++ paramMap - val featuresType = schema(map(featuresCol)).dataType - // TODO: Support casting Array[Double] and Array[Float] to Vector. - require(featuresType.isInstanceOf[VectorUDT], - s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.") - if (fitting) { - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.") - } - val fieldNames = schema.fieldNames - require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.") - require(!fieldNames.contains(map(predictionCol)), - s"Prediction column ${map(predictionCol)} already exists.") - val outputFields = schema.fields ++ Seq( - StructField(map(scoreCol), DoubleType, false), - StructField(map(predictionCol), DoubleType, false)) - StructType(outputFields) - } -} /** + * :: AlphaComponent :: + * * Logistic regression. + * Currently, this class only supports binary classification. 
*/ -class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams { +@AlphaComponent +class LogisticRegression + extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] + with LogisticRegressionParams { setRegParam(0.1) setMaxIter(100) setThreshold(0.5) + /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { - transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ - val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) - .map { case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - }.persist(StorageLevel.MEMORY_AND_DISK) + override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
+ val oldDataset = extractLabeledPoints(dataset, paramMap) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + } + + // Train model val lr = new LogisticRegressionWithLBFGS lr.optimizer - .setRegParam(map(regParam)) - .setNumIterations(map(maxIter)) - val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights) - instances.unpersist() - // copy model params - Params.inheritValues(map, this, lrm) - lrm - } + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) + val oldModel = lr.run(oldDataset) + val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept) - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = true) + if (handlePersistence) { + oldDataset.unpersist() + } + lrm } } + /** * :: AlphaComponent :: + * * Model produced by [[LogisticRegression]]. 
*/ @AlphaComponent class LogisticRegressionModel private[ml] ( override val parent: LogisticRegression, override val fittingParamMap: ParamMap, - weights: Vector) - extends Model[LogisticRegressionModel] with LogisticRegressionParams { + val weights: Vector, + val intercept: Double) + extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] + with LogisticRegressionParams { + + setThreshold(0.5) + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = false) + private val margin: Vector => Double = (features) => { + BLAS.dot(features, weights) + intercept + } + + private val score: Vector => Double = (features) => { + val m = margin(features) + 1.0 / (1.0 + math.exp(-m)) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This is overridden (a) to be more efficient (avoiding re-computing values when creating + // multiple output columns) and (b) to handle threshold, which the abstractions do not use. + // TODO: We should abstract away the steps defined by UDFs below so that the abstractions + // can call whichever UDFs are needed to create the output columns. + + // Check schema transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ + val map = this.paramMap ++ paramMap - val score: Vector => Double = (v) => { - val margin = BLAS.dot(v, weights) - 1.0 / (1.0 + math.exp(-margin)) + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. 
+ // rawPrediction (0.0, margin) + // probability (1.0-score, score) + // prediction (1.0 if P(class 1) > threshold, else 0.0) + var tmpData = dataset + var numColsOutput = 0 + if (map(rawPredictionCol) != "") { + val features2raw: Vector => Vector = (features) => predictRaw(features) + tmpData = tmpData.withColumn(map(rawPredictionCol), + callUDF(features2raw, new VectorUDT, col(map(featuresCol)))) + numColsOutput += 1 + } + if (map(probabilityCol) != "") { + if (map(rawPredictionCol) != "") { + val raw2prob = udf { (rawPreds: Vector) => + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) + Vectors.dense(1.0 - prob1, prob1): Vector + } + tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol)))) + } else { + val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector } + tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol)))) + } + numColsOutput += 1 + } + if (map(predictionCol) != "") { + val t = map(threshold) + if (map(probabilityCol) != "") { + val predict = udf { probs: Vector => + if (probs(1) > t) 1.0 else 0.0 + } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol)))) + } else if (map(rawPredictionCol) != "") { + val predict = udf { rawPreds: Vector => + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) + if (prob1 > t) 1.0 else 0.0 + } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol)))) + } else { + val predict = udf { features: Vector => this.predict(features) } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol)))) + } + numColsOutput += 1 } - val t = map(threshold) - val predict: Double => Double = (score) => { - if (score > t) 1.0 else 0.0 + if (numColsOutput == 0) { + this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" + + " since no output columns were set.") } - dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) - .select(Star(None), 
predict.call(map(scoreCol).attr) as map(predictionCol)) + tmpData + } + + override val numClasses: Int = 2 + + /** + * Predict label for the given feature vector. + * The behavior of this can be adjusted using [[threshold]]. + */ + override protected def predict(features: Vector): Double = { + logDebug(s"LR.predict with threshold: ${paramMap(threshold)}") + if (score(features) > paramMap(threshold)) 1 else 0 + } + + override protected def predictProbabilities(features: Vector): Vector = { + val s = score(features) + Vectors.dense(1.0 - s, s) + } + + override protected def predictRaw(features: Vector): Vector = { + val m = margin(features) + Vectors.dense(0.0, m) + } + + override protected def copy(): LogisticRegressionModel = { + val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) + Params.inheritValues(this.paramMap, this, m) + m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala new file mode 100644 index 0000000000000..bd8caac855981 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, StructType} + + +/** + * Params for probabilistic classification. + */ +private[classification] trait ProbabilisticClassifierParams + extends ClassifierParams with HasProbabilityCol { + + override protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) + val map = this.paramMap ++ paramMap + addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT) + } +} + + +/** + * :: AlphaComponent :: + * + * Single-label binary or multiclass classifier which can output class conditional probabilities. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class ProbabilisticClassifier[ + FeaturesType, + E <: ProbabilisticClassifier[FeaturesType, E, M], + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams { + + /** @group setParam */ + def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E] +} + + +/** + * :: AlphaComponent :: + * + * Model produced by a [[ProbabilisticClassifier]]. + * Classes are indexed {0, 1, ..., numClasses - 1}. 
+ * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class ProbabilisticClassificationModel[ + FeaturesType, + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { + + /** @group setParam */ + def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]] + * - probability of each class as [[probabilityCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. + + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + val (numColsOutput, outputData) = + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + + // Output selected columns only. 
+ if (map(probabilityCol) != "") { + // output probabilities + val features2probs: FeaturesType => Vector = (features) => { + tmpModel.predictProbabilities(features) + } + outputData.withColumn(map(probabilityCol), + callUDF(features2probs, new VectorUDT, col(map(featuresCol)))) + } else { + if (numColsOutput == 0) { + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + } + + /** + * :: DeveloperApi :: + * + * Predict the probability of each class given the features. + * These predictions are also called class conditional probabilities. + * + * WARNING: Not all models output well-calibrated probability estimates! These probabilities + * should be treated as confidences, not precise probabilities. + * + * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. + */ + @DeveloperApi + protected def predictProbabilities(features: FeaturesType): Vector +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 12473cb2b5719..2360f4479f1c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -18,44 +18,53 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ +import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType + /** * :: AlphaComponent :: + * * Evaluator for binary classification, which expects two input columns: 
score and label. */ @AlphaComponent class BinaryClassificationEvaluator extends Evaluator with Params - with HasScoreCol with HasLabelCol { + with HasRawPredictionCol with HasLabelCol { - /** param for metric name in evaluation */ + /** + * param for metric name in evaluation + * @group param + */ val metricName: Param[String] = new Param(this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + + /** @group getParam */ def getMetricName: String = get(metricName) + + /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) + /** @group setParam */ + def setScoreCol(value: String): this.type = set(rawPredictionCol, value) + + /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { val map = this.paramMap ++ paramMap val schema = dataset.schema - val scoreType = schema(map(scoreCol)).dataType - require(scoreType == DoubleType, - s"Score column ${map(scoreCol)} must be double type but found $scoreType") - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Label column ${map(labelCol)} must be double type but found $labelType") + checkInputColumn(schema, map(rawPredictionCol), new VectorUDT) + checkInputColumn(schema, map(labelCol), DoubleType) - import dataset.sqlContext._ - val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) - .map { case Row(score: Double, label: Double) => - (score, label) + // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. 
+ val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) + .map { case Row(rawPrediction: Vector, label: Double) => + (rawPrediction(1), label) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = map(metricName) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 0956062643f23..6131ba8832691 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -31,11 +31,18 @@ import org.apache.spark.sql.types.DataType @AlphaComponent class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { - /** number of features */ + /** + * number of features + * @group param + */ val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) - def setNumFeatures(value: Int) = set(numFeatures, value) + + /** @group getParam */ def getNumFeatures: Int = get(numFeatures) + /** @group setParam */ + def setNumFeatures(value: Int) = set(numFeatures, value) + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72825f6e02182..1142aa4f8e73d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -23,8 +23,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.functions._ import 
org.apache.spark.sql.types.{StructField, StructType} /** @@ -40,24 +39,23 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with @AlphaComponent class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val input = dataset.select(map(inputCol).attr) - .map { case Row(v: Vector) => - v - } + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) Params.inheritValues(map, this, model) model } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = this.paramMap ++ paramMap val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], @@ -80,20 +78,20 @@ class StandardScalerModel private[ml] ( scaler: feature.StandardScalerModel) extends Model[StandardScalerModel] with StandardScalerParams { + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap 
- val scale: (Vector) => Vector = (v) => { - scaler.transform(v) - } - dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) + dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = this.paramMap ++ paramMap val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index e622a5cf9e6f3..0b1f90daa7d8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -29,11 +29,11 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType} @AlphaComponent class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { - protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { _.toLowerCase.split("\\s") } - protected override def validateInputType(inputType: DataType): Unit = { + override protected def validateInputType(inputType: DataType): Unit = { require(inputType == StringType, s"Input type must be string type but got $inputType.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala new file mode 100644 index 0000000000000..dfb89cc8d4af3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl.estimator + +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + + +/** + * :: DeveloperApi :: + * + * Trait for parameters for prediction (regression and classification). + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait PredictorParams extends Params + with HasLabelCol with HasFeaturesCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param paramMap additional parameters + * @param fitting whether this is in fitting + * @param featuresDataType SQL DataType for FeaturesType. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. 
+ * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val map = this.paramMap ++ paramMap + // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector + checkInputColumn(schema, map(featuresCol), featuresDataType) + if (fitting) { + // The label column is only checked when fitting. TODO: Allow other numeric types + checkInputColumn(schema, map(labelCol), DoubleType) + } + addOutputColumn(schema, map(predictionCol), DoubleType) + } +} + +/** + * :: AlphaComponent :: + * + * Abstraction for prediction problems (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.Vector]] for vector features. + * @tparam Learner Specialization of this class. If you subclass this type, use this type + * parameter to specify the concrete type. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Predictor[ + FeaturesType, + Learner <: Predictor[FeaturesType, Learner, M], + M <: PredictionModel[FeaturesType, M]] + extends Estimator[M] with PredictorParams { + + /** @group setParam */ + def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] + + override def fit(dataset: DataFrame, paramMap: ParamMap): M = { + // This handles a few items such as schema validation. + // Developers only need to implement train().
+ transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val model = train(dataset, map) + Params.inheritValues(map, this, model) // copy params to model + model + } + + /** + * :: DeveloperApi :: + * + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already + * been combined with the embedded ParamMap. + * @return Fitted model + */ + @DeveloperApi + protected def train(dataset: DataFrame, paramMap: ParamMap): M + + /** + * :: DeveloperApi :: + * + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + @DeveloperApi + protected def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType) + } + + /** + * Extract [[labelCol]] and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + */ + protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { + val map = this.paramMap ++ paramMap + dataset.select(map(labelCol), map(featuresCol)) + .map { case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + } + } +} + +/** + * :: AlphaComponent :: + * + * Abstraction for a model for prediction tasks (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. 
+ * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] + extends Model[M] with PredictorParams { + + /** @group setParam */ + def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M] + + /** @group setParam */ + def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] + + /** + * :: DeveloperApi :: + * + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + @DeveloperApi + protected def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType) + } + + /** + * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing + * the predictions as a new column [[predictionCol]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset with [[predictionCol]] of type [[Double]] + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. 
+ + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + if (map(predictionCol) != "") { + val pred: FeaturesType => Double = (features) => { + tmpModel.predict(features) + } + dataset.withColumn(map(predictionCol), callUDF(pred, DoubleType, col(map(featuresCol)))) + } else { + this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + + " since no output columns were set.") + dataset + } + } + + /** + * :: DeveloperApi :: + * + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + @DeveloperApi + protected def predict(features: FeaturesType): Double + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + */ + protected def copy(): M +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index 51cd48c90432a..b45bd1499b72e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -20,5 +20,19 @@ package org.apache.spark /** * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. + * + * @groupname param Parameters + * @groupdesc param A list of (hyper-)parameter keys this algorithm can take. Users can set and get + * the parameter values through setters and getters, respectively. 
+ * @groupprio param -5 + * + * @groupname setParam Parameter setters + * @groupprio setParam 5 + * + * @groupname getParam Parameter getters + * @groupprio getParam 6 + * + * @groupname Ungrouped Members + * @groupprio Ungrouped 0 */ package object ml diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 04f9cfb1bfc2f..17ece897a6c55 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -22,8 +22,10 @@ import scala.collection.mutable import java.lang.reflect.Modifier -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable +import org.apache.spark.sql.types.{DataType, StructField, StructType} + /** * :: AlphaComponent :: @@ -65,37 +67,47 @@ class Param[T] ( // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None) +class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double]) extends Param[Double](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None) +class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int]) extends Param[Int](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. 
*/ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None) +class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float]) extends Param[Float](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None) +class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long]) extends Param[Long](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None) +class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean]) extends Param[Boolean](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -158,16 +170,23 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - private[ml] def set[T](param: Param[T], value: T): this.type = { + protected def set[T](param: Param[T], value: T): this.type = { require(param.parent.eq(this)) paramMap.put(param.asInstanceOf[Param[Any]], value) this } + /** + * Sets a parameter (by name) in the embedded param map. + */ + private[ml] def set(param: String, value: Any): this.type = { + set(getParam(param), value) + } + /** * Gets the value of a parameter in the embedded param map. 
*/ - private[ml] def get[T](param: Param[T]): T = { + protected def get[T](param: Param[T]): T = { require(param.parent.eq(this)) paramMap(param) } @@ -176,9 +195,40 @@ trait Params extends Identifiable with Serializable { * Internal param map. */ protected val paramMap: ParamMap = ParamMap.empty + + /** + * Check that the input column in the given schema has the expected data type. + * @param colName Name of the input column (the column itself, not the param that holds it). + * @param dataType SQL DataType of the input column. + */ + protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Input column $colName must be of type $dataType" + + s" but was actually $actualDataType.") + } + + protected def addOutputColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + if (colName.length == 0) return schema + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Prediction column $colName already exists.") + val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) + StructType(outputFields) + } } -private[ml] object Params { +/** + * :: DeveloperApi :: + * + * Helper functionality for developers. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] object Params { /** * Copies parameter values from the parent estimator to the child model it produced.
@@ -272,7 +322,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten def copy: ParamMap = new ParamMap(map.clone()) override def toString: String = { - map.map { case (param, value) => + map.toSeq.sortBy(_._1.name).map { case (param, value) => s"\t${param.parent.uid}-${param.name}: $value" }.mkString("{\n", ",\n", "\n}") } @@ -286,7 +336,6 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten new ParamMap(this.map ++ other.map) } - /** * Adds all parameters from the input param map into this param map. */ @@ -304,6 +353,11 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten ParamPair(param, value) } } + + /** + * Number of param pairs in this set. + */ + def size: Int = map.size } object ParamMap { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index ef141d3eb2b06..1a70322b4cace 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -17,58 +17,124 @@ package org.apache.spark.ml.param +/* NOTE TO DEVELOPERS: + * If you mix these parameter traits into your algorithm, please add a setter method as well + * so that users may use a builder pattern: + * val myLearner = new MyLearner().setParam1(x).setParam2(y)... 
+ */ + private[ml] trait HasRegParam extends Params { - /** param for regularization parameter */ + /** + * param for regularization parameter + * @group param + */ val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + + /** @group getParam */ def getRegParam: Double = get(regParam) } private[ml] trait HasMaxIter extends Params { - /** param for max number of iterations */ + /** + * param for max number of iterations + * @group param + */ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + + /** @group getParam */ def getMaxIter: Int = get(maxIter) } private[ml] trait HasFeaturesCol extends Params { - /** param for features column name */ + /** + * param for features column name + * @group param + */ val featuresCol: Param[String] = new Param(this, "featuresCol", "features column name", Some("features")) + + /** @group getParam */ def getFeaturesCol: String = get(featuresCol) } private[ml] trait HasLabelCol extends Params { - /** param for label column name */ + /** + * param for label column name + * @group param + */ val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) - def getLabelCol: String = get(labelCol) -} -private[ml] trait HasScoreCol extends Params { - /** param for score column name */ - val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score")) - def getScoreCol: String = get(scoreCol) + /** @group getParam */ + def getLabelCol: String = get(labelCol) } private[ml] trait HasPredictionCol extends Params { - /** param for prediction column name */ + /** + * param for prediction column name + * @group param + */ val predictionCol: Param[String] = new Param(this, "predictionCol", "prediction column name", Some("prediction")) + + /** @group getParam */ def getPredictionCol: String = get(predictionCol) } +private[ml] trait HasRawPredictionCol extends Params { + /** + * param for raw prediction column name + * 
@group param + */ + val rawPredictionCol: Param[String] = + new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name", + Some("rawPrediction")) + + /** @group getParam */ + def getRawPredictionCol: String = get(rawPredictionCol) +} + +private[ml] trait HasProbabilityCol extends Params { + /** + * param for predicted class conditional probabilities column name + * @group param + */ + val probabilityCol: Param[String] = + new Param(this, "probabilityCol", "column name for predicted class conditional probabilities", + Some("probability")) + + /** @group getParam */ + def getProbabilityCol: String = get(probabilityCol) +} + private[ml] trait HasThreshold extends Params { - /** param for threshold in (binary) prediction */ + /** + * param for threshold in (binary) prediction + * @group param + */ val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") + + /** @group getParam */ def getThreshold: Double = get(threshold) } private[ml] trait HasInputCol extends Params { - /** param for input column name */ + /** + * param for input column name + * @group param + */ val inputCol: Param[String] = new Param(this, "inputCol", "input column name") + + /** @group getParam */ def getInputCol: String = get(inputCol) } private[ml] trait HasOutputCol extends Params { - /** param for output column name */ + /** + * param for output column name + * @group param + */ val outputCol: Param[String] = new Param(this, "outputCol", "output column name") + + /** @group getParam */ def getOutputCol: String = get(outputCol) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala new file mode 100644 index 0000000000000..7bb69df65362b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -0,0 +1,1163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import java.{util => ju} + +import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.Sorting +import scala.util.hashing.byteswap64 + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.jblas.DoubleMatrix +import org.netlib.util.intW + +import org.apache.spark.{Logging, Partitioner} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.optimization.NNLS +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} +import org.apache.spark.util.random.XORShiftRandom + +/** + * Common params for ALS. + */ +private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam + with HasPredictionCol { + + /** + * Param for rank of the matrix factorization. 
+ * @group param + */ + val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + + /** @group getParam */ + def getRank: Int = get(rank) + + /** + * Param for number of user blocks. + * @group param + */ + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + + /** @group getParam */ + def getNumUserBlocks: Int = get(numUserBlocks) + + /** + * Param for number of item blocks. + * @group param + */ + val numItemBlocks = + new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + + /** @group getParam */ + def getNumItemBlocks: Int = get(numItemBlocks) + + /** + * Param to decide whether to use implicit preference. + * @group param + */ + val implicitPrefs = + new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + + /** @group getParam */ + def getImplicitPrefs: Boolean = get(implicitPrefs) + + /** + * Param for the alpha parameter in the implicit preference formulation. + * @group param + */ + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + + /** @group getParam */ + def getAlpha: Double = get(alpha) + + /** + * Param for the column name for user ids. + * @group param + */ + val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + + /** @group getParam */ + def getUserCol: String = get(userCol) + + /** + * Param for the column name for item ids. + * @group param + */ + val itemCol = + new Param[String](this, "itemCol", "column name for item ids", Some("item")) + + /** @group getParam */ + def getItemCol: String = get(itemCol) + + /** + * Param for the column name for ratings. + * @group param + */ + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + + /** @group getParam */ + def getRatingCol: String = get(ratingCol) + + /** + * Param for whether to apply nonnegativity constraints. 
+ * @group param + */ + val nonnegative = new BooleanParam( + this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false)) + + /** @group getParam */ + def getNonnegative: Boolean = get(nonnegative) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @param paramMap extra params + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + assert(schema(map(userCol)).dataType == IntegerType) + assert(schema(map(itemCol)).dataType == IntegerType) + val ratingType = schema(map(ratingCol)).dataType + assert(ratingType == FloatType || ratingType == DoubleType) + val predictionColName = map(predictionCol) + assert(!schema.fieldNames.contains(predictionColName), + s"Prediction column $predictionColName already exists.") + val newFields = schema.fields :+ StructField(predictionColName, FloatType, nullable = false) + StructType(newFields) + } +} + +/** + * Model fitted by ALS. + */ +class ALSModel private[ml] ( + override val parent: ALS, + override val fittingParamMap: ParamMap, + k: Int, + userFactors: RDD[(Int, Array[Float])], + itemFactors: RDD[(Int, Array[Float])]) + extends Model[ALSModel] with ALSParams { + + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + import dataset.sqlContext.implicits._ + val map = this.paramMap ++ paramMap + val users = userFactors.toDF("id", "features") + val items = itemFactors.toDF("id", "features") + + // Register a UDF for DataFrame, and then + // create a new column named map(predictionCol) by running the predict UDF.
+ val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => + if (userFeatures != null && itemFeatures != null) { + blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + } else { + Float.NaN + } + } + dataset + .join(users, dataset(map(userCol)) === users("id"), "left") + .join(items, dataset(map(itemCol)) === items("id"), "left") + .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + + +/** + * Alternating Least Squares (ALS) matrix factorization. + * + * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, + * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices. + * The general approach is iterative. During each iteration, one of the factor matrices is held + * constant, while the other is solved for using least squares. The newly-solved factor matrix is + * then held constant while solving for the other factor matrix. + * + * This is a blocked implementation of the ALS factorization algorithm that groups the two sets + * of factors (referred to as "users" and "products") into blocks and reduces communication by only + * sending one copy of each user vector to each product block on each iteration, and only for the + * product blocks that need that user's feature vector. This is achieved by pre-computing some + * information about the ratings matrix to determine the "out-links" of each user (which blocks of + * products it will contribute to) and "in-link" information for each product (which of the feature + * vectors it receives from each user block it will depend on). This allows us to send only an + * array of feature vectors between each user block and product block, and have the product block + * find the users' ratings and update the products based on these messages. 
+ * + * For implicit preference data, the algorithm used is based on + * "Collaborative Filtering for Implicit Feedback Datasets", available at + * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here. + * + * Essentially instead of finding the low-rank approximations to the rating matrix `R`, + * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of + * indicated user + * preferences rather than explicit ratings given to items. + */ +class ALS extends Estimator[ALSModel] with ALSParams { + + import org.apache.spark.ml.recommendation.ALS.Rating + + /** @group setParam */ + def setRank(value: Int): this.type = set(rank, value) + + /** @group setParam */ + def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) + + /** @group setParam */ + def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) + + /** @group setParam */ + def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) + + /** @group setParam */ + def setAlpha(value: Double): this.type = set(alpha, value) + + /** @group setParam */ + def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ + def setItemCol(value: String): this.type = set(itemCol, value) + + /** @group setParam */ + def setRatingCol(value: String): this.type = set(ratingCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ + def setNonnegative(value: Boolean): this.type = set(nonnegative, value) + + /** + * Sets both numUserBlocks and numItemBlocks to the specific value. 
+ * @group setParam + */ + def setNumBlocks(value: Int): this.type = { + setNumUserBlocks(value) + setNumItemBlocks(value) + this + } + + setMaxIter(20) + setRegParam(1.0) + + override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { + val map = this.paramMap ++ paramMap + val ratings = dataset + .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) + .map { row => + Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) + } + val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), + numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), + maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), + alpha = map(alpha), nonnegative = map(nonnegative)) + val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: DeveloperApi :: + * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is + * exposed as a developer API for users who do need other ID types. But it is not recommended + * because it increases the shuffle size and memory requirement during training. For simplicity, + * users and items must have the same type. The number of distinct users/items should be smaller + * than 2 billion. + */ +@DeveloperApi +object ALS extends Logging { + + /** Rating class for better code readability. */ + case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) + + /** Trait for least squares solvers applied to the normal equation. */ + private[recommendation] trait LeastSquaresNESolver extends Serializable { + /** Solves a least squares problem (possibly with other constraints). */ + def solve(ne: NormalEquation, lambda: Double): Array[Float] + } + + /** Cholesky solver for least square problems. 
*/ + private[recommendation] class CholeskySolver extends LeastSquaresNESolver { + + private val upper = "U" + + /** + * Solves a least squares problem with L2 regularization: + * + * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * + * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) + * @param lambda regularization constant, which will be scaled by n + * @return the solution x + */ + override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { + val k = ne.k + // Add scaled lambda to the diagonals of AtA, which sit at packed offsets 0, 2, 5, 9, ... + val scaledlambda = lambda * ne.n + var i = 0 + var j = 2 + while (i < ne.triK) { + ne.ata(i) += scaledlambda + i += j + j += 1 + } + val info = new intW(0) + lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info) + val code = info.`val` + assert(code == 0, s"lapack.dppsv returned $code.") + val x = new Array[Float](k) + i = 0 + while (i < k) { + x(i) = ne.atb(i).toFloat + i += 1 + } + ne.reset() + x + } + } + + /** NNLS solver. */ + private[recommendation] class NNLSSolver extends LeastSquaresNESolver { + private var rank: Int = -1 + private var workspace: NNLS.Workspace = _ + private var ata: DoubleMatrix = _ + private var initialized: Boolean = false + + private def initialize(rank: Int): Unit = { + if (!initialized) { + this.rank = rank + workspace = NNLS.createWorkspace(rank) + ata = new DoubleMatrix(rank, rank) + initialized = true + } else { + require(this.rank == rank) + } + } + + /** + * Solves a nonnegative least squares problem with L2 regularization: + * + * min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * subject to x >= 0 + */ + override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { + val rank = ne.k + initialize(rank) + fillAtA(ne.ata, lambda * ne.n) + val x = NNLS.solve(ata, new DoubleMatrix(rank, 1, ne.atb: _*), workspace) + ne.reset() + x.map(x => x.toFloat) + } + + /** + * Given a triangular matrix in the packed order used by NormalEquation.ata, compute the full symmetric square + *
matrix that it represents, storing it into `ata` and adding lambda to its diagonal. + */ + private def fillAtA(triAtA: Array[Double], lambda: Double): Unit = { + var i = 0 + var pos = 0 + var a = 0.0 + val data = ata.data + while (i < rank) { + var j = 0 + while (j <= i) { + a = triAtA(pos) + data(i * rank + j) = a + data(j * rank + i) = a + pos += 1 + j += 1 + } + data(i * rank + i) += lambda + i += 1 + } + } + } + + /** Representing a normal equation (ALS' subproblem). */ + private[recommendation] class NormalEquation(val k: Int) extends Serializable { + + /** Number of entries in the upper triangular part of a k-by-k matrix. */ + val triK = k * (k + 1) / 2 + /** A^T^ * A */ + val ata = new Array[Double](triK) + /** A^T^ * b */ + val atb = new Array[Double](k) + /** Number of observations. */ + var n = 0 + + private val da = new Array[Double](k) + private val upper = "U" + + private def copyToDouble(a: Array[Float]): Unit = { + var i = 0 + while (i < k) { + da(i) = a(i) + i += 1 + } + } + + /** Adds an observation. */ + def add(a: Array[Float], b: Float): this.type = { + require(a.length == k) + copyToDouble(a) + blas.dspr(upper, k, 1.0, da, 1, ata) + blas.daxpy(k, b.toDouble, da, 1, atb, 1) + n += 1 + this + } + + /** + * Adds an observation with implicit feedback. Note that this does not increment the counter. + */ + def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + require(a.length == k) + // Extension to the original paper to handle b < 0. confidence is a function of |b| instead + // so that it is never negative. + val confidence = 1.0 + alpha * math.abs(b) + copyToDouble(a) + blas.dspr(upper, k, confidence - 1.0, da, 1, ata) + // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. + if (b > 0) { + blas.daxpy(k, confidence, da, 1, atb, 1) + } + this + } + + /** Merges another normal equation object.
*/ + def merge(other: NormalEquation): this.type = { + require(other.k == k) + blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1) + blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1) + n += other.n + this + } + + /** Resets everything to zero, which should be called after each solve. */ + def reset(): Unit = { + ju.Arrays.fill(ata, 0.0) + ju.Arrays.fill(atb, 0.0) + n = 0 + } + } + + /** + * Implementation of the ALS algorithm. + */ + def train[ID: ClassTag]( // scalastyle:ignore + ratings: RDD[Rating[ID]], + rank: Int = 10, + numUserBlocks: Int = 10, + numItemBlocks: Int = 10, + maxIter: Int = 10, + regParam: Double = 1.0, + implicitPrefs: Boolean = false, + alpha: Double = 1.0, + nonnegative: Boolean = false, + intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + seed: Long = 0L)( + implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { + require(intermediateRDDStorageLevel != StorageLevel.NONE, + "ALS is not designed to run without persisting intermediate RDDs.") + val sc = ratings.sparkContext + val userPart = new ALSPartitioner(numUserBlocks) + val itemPart = new ALSPartitioner(numItemBlocks) + val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) + val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) + val solver = if (nonnegative) new NNLSSolver else new CholeskySolver + val blockRatings = partitionRatings(ratings, userPart, itemPart) + .persist(intermediateRDDStorageLevel) + val (userInBlocks, userOutBlocks) = + makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel) + // materialize blockRatings and user blocks + userOutBlocks.count() + val swappedBlockRatings = blockRatings.map { + case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => + ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) + } + val (itemInBlocks, 
itemOutBlocks) = + makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel) + // materialize item blocks + itemOutBlocks.count() + val seedGen = new XORShiftRandom(seed) + var userFactors = initialize(userInBlocks, rank, seedGen.nextLong()) + var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) + if (implicitPrefs) { + for (iter <- 1 to maxIter) { + userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) + val previousItemFactors = itemFactors + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder, implicitPrefs, alpha, solver) + previousItemFactors.unpersist() + if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { + itemFactors.checkpoint() + } + itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel) + val previousUserFactors = userFactors + userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder, implicitPrefs, alpha, solver) + previousUserFactors.unpersist() + } + } else { + for (iter <- 0 until maxIter) { + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder, solver = solver) + userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder, solver = solver) + } + } + val userIdAndFactors = userInBlocks + .mapValues(_.srcIds) + .join(userFactors) + .mapPartitions({ items => + items.flatMap { case (_, (ids, factors)) => + ids.view.zip(factors) + } + // Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks + // and userFactors. 
+ }, preservesPartitioning = true) + .setName("userFactors") + .persist(finalRDDStorageLevel) + val itemIdAndFactors = itemInBlocks + .mapValues(_.srcIds) + .join(itemFactors) + .mapPartitions({ items => + items.flatMap { case (_, (ids, factors)) => + ids.view.zip(factors) + } + }, preservesPartitioning = true) + .setName("itemFactors") + .persist(finalRDDStorageLevel) + if (finalRDDStorageLevel != StorageLevel.NONE) { + userIdAndFactors.count() + itemFactors.unpersist() + itemIdAndFactors.count() + userInBlocks.unpersist() + userOutBlocks.unpersist() + itemInBlocks.unpersist() + itemOutBlocks.unpersist() + blockRatings.unpersist() + } + (userIdAndFactors, itemIdAndFactors) + } + + /** + * Factor block that stores factors (Array[Float]) in an Array. + */ + private type FactorBlock = Array[Array[Float]] + + /** + * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to + * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the + * src factors in this block to send to dst block 0. + */ + private type OutBlock = Array[Array[Int]] + + /** + * In-link block for computing src (user/item) factors. This includes the original src IDs + * of the elements within this block as well as encoded dst (item/user) indices and corresponding + * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original + * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices. + * For example, if we have an in-link record + * + * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0}, + * + * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which + * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3). + * + * We use a CSC-like (compressed sparse column) format to store the in-link information. 
So we can + * compute src factors one after another using only one normal equation instance. + * + * @param srcIds src ids (ordered) + * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and + * ratings are associated with srcIds(i). + * @param dstEncodedIndices encoded dst indices + * @param ratings ratings + * + * @see [[LocalIndexEncoder]] + */ + private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag]( + srcIds: Array[ID], + dstPtrs: Array[Int], + dstEncodedIndices: Array[Int], + ratings: Array[Float]) { + /** Size of the block. */ + def size: Int = ratings.length + require(dstEncodedIndices.length == size) + require(dstPtrs.length == srcIds.length + 1) + } + + /** + * Initializes factors randomly given the in-link blocks. + * + * @param inBlocks in-link blocks + * @param rank rank + * @return initialized factor blocks + */ + private def initialize[ID]( + inBlocks: RDD[(Int, InBlock[ID])], + rank: Int, + seed: Long): RDD[(Int, FactorBlock)] = { + // Choose a unit vector uniformly at random from the unit sphere, but from the + // "first quadrant" where all elements are nonnegative. This can be done by choosing + // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. + // This appears to create factorizations that have a slightly better reconstruction + // (<1%) compared picking elements uniformly at random in [0,1]. + inBlocks.map { case (srcBlockId, inBlock) => + val random = new XORShiftRandom(byteswap64(seed ^ srcBlockId)) + val factors = Array.fill(inBlock.srcIds.length) { + val factor = Array.fill(rank)(random.nextGaussian().toFloat) + val nrm = blas.snrm2(rank, factor, 1) + blas.sscal(rank, 1.0f / nrm, factor, 1) + factor + } + (srcBlockId, factors) + } + } + + /** + * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays. 
+ */ + private[recommendation] case class RatingBlock[@specialized(Int, Long) ID: ClassTag]( + srcIds: Array[ID], + dstIds: Array[ID], + ratings: Array[Float]) { + /** Size of the block. */ + def size: Int = srcIds.length + require(dstIds.length == srcIds.length) + require(ratings.length == srcIds.length) + } + + /** + * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing. + */ + private[recommendation] class RatingBlockBuilder[@specialized(Int, Long) ID: ClassTag] + extends Serializable { + + private val srcIds = mutable.ArrayBuilder.make[ID] + private val dstIds = mutable.ArrayBuilder.make[ID] + private val ratings = mutable.ArrayBuilder.make[Float] + var size = 0 + + /** Adds a rating. */ + def add(r: Rating[ID]): this.type = { + size += 1 + srcIds += r.user + dstIds += r.item + ratings += r.rating + this + } + + /** Merges another [[RatingBlockBuilder]]. */ + def merge(other: RatingBlock[ID]): this.type = { + size += other.srcIds.length + srcIds ++= other.srcIds + dstIds ++= other.dstIds + ratings ++= other.ratings + this + } + + /** Builds a [[RatingBlock]]. */ + def build(): RatingBlock[ID] = { + RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result()) + } + } + + /** + * Partitions raw ratings into blocks. + * + * @param ratings raw ratings + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * + * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock) + */ + private def partitionRatings[ID: ClassTag]( + ratings: RDD[Rating[ID]], + srcPart: Partitioner, + dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = { + + /* The implementation produces the same result as the following but generates less objects. 
+ + ratings.map { r => + ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) + }.aggregateByKey(new RatingBlockBuilder)( + seqOp = (b, r) => b.add(r), + combOp = (b0, b1) => b0.merge(b1.build())) + .mapValues(_.build()) + */ + + val numPartitions = srcPart.numPartitions * dstPart.numPartitions + ratings.mapPartitions { iter => + val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID]) + iter.flatMap { r => + val srcBlockId = srcPart.getPartition(r.user) + val dstBlockId = dstPart.getPartition(r.item) + val idx = srcBlockId + srcPart.numPartitions * dstBlockId + val builder = builders(idx) + builder.add(r) + if (builder.size >= 2048) { // 2048 * (3 * 4) = 24k + builders(idx) = new RatingBlockBuilder + Iterator.single(((srcBlockId, dstBlockId), builder.build())) + } else { + Iterator.empty + } + } ++ { + builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) => + val srcBlockId = idx % srcPart.numPartitions + val dstBlockId = idx / srcPart.numPartitions + ((srcBlockId, dstBlockId), block.build()) + } + } + }.groupByKey().mapValues { blocks => + val builder = new RatingBlockBuilder[ID] + blocks.foreach(builder.merge) + builder.build() + }.setName("ratingBlocks") + } + + /** + * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples. + * @param encoder encoder for dst indices + */ + private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag]( + encoder: LocalIndexEncoder)( + implicit ord: Ordering[ID]) { + + private val srcIds = mutable.ArrayBuilder.make[ID] + private val dstEncodedIndices = mutable.ArrayBuilder.make[Int] + private val ratings = mutable.ArrayBuilder.make[Float] + + /** + * Adds a dst block of (srcId, dstLocalIndex, rating) tuples. 
+ * + * @param dstBlockId dst block ID + * @param srcIds original src IDs + * @param dstLocalIndices dst local indices + * @param ratings ratings + */ + def add( + dstBlockId: Int, + srcIds: Array[ID], + dstLocalIndices: Array[Int], + ratings: Array[Float]): this.type = { + val sz = srcIds.length + require(dstLocalIndices.length == sz) + require(ratings.length == sz) + this.srcIds ++= srcIds + this.ratings ++= ratings + var j = 0 + while (j < sz) { + this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j)) + j += 1 + } + this + } + + /** Builds a [[UncompressedInBlock]]. */ + def build(): UncompressedInBlock[ID] = { + new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result()) + } + } + + /** + * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays. + */ + private[recommendation] class UncompressedInBlock[@specialized(Int, Long) ID: ClassTag]( + val srcIds: Array[ID], + val dstEncodedIndices: Array[Int], + val ratings: Array[Float])( + implicit ord: Ordering[ID]) { + + /** Size the of block. */ + def length: Int = srcIds.length + + /** + * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a + * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. + * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. 
+ */ + def compress(): InBlock[ID] = { + val sz = length + assert(sz > 0, "Empty in-link block should not exist.") + sort() + val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[ID] + val dstCountsBuilder = mutable.ArrayBuilder.make[Int] + var preSrcId = srcIds(0) + uniqueSrcIdsBuilder += preSrcId + var curCount = 1 + var i = 1 + var j = 0 + while (i < sz) { + val srcId = srcIds(i) + if (srcId != preSrcId) { + uniqueSrcIdsBuilder += srcId + dstCountsBuilder += curCount + preSrcId = srcId + j += 1 + curCount = 0 + } + curCount += 1 + i += 1 + } + dstCountsBuilder += curCount + val uniqueSrcIds = uniqueSrcIdsBuilder.result() + val numUniqueSrdIds = uniqueSrcIds.length + val dstCounts = dstCountsBuilder.result() + val dstPtrs = new Array[Int](numUniqueSrdIds + 1) + var sum = 0 + i = 0 + while (i < numUniqueSrdIds) { + sum += dstCounts(i) + i += 1 + dstPtrs(i) = sum + } + InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings) + } + + private def sort(): Unit = { + val sz = length + // Since there might be interleaved log messages, we insert a unique id for easy pairing. + val sortId = Utils.random.nextInt() + logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)") + val start = System.nanoTime() + val sorter = new Sorter(new UncompressedInBlockSort[ID]) + sorter.sort(this, 0, length, Ordering[KeyWrapper[ID]]) + val duration = (System.nanoTime() - start) / 1e9 + logDebug(s"Sorting took $duration seconds. (sortId = $sortId)") + } + } + + /** + * A wrapper that holds a primitive key. + * + * @see [[UncompressedInBlockSort]] + */ + private class KeyWrapper[@specialized(Int, Long) ID: ClassTag]( + implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] { + + var key: ID = _ + + override def compare(that: KeyWrapper[ID]): Int = { + ord.compare(key, that.key) + } + + def setKey(key: ID): this.type = { + this.key = key + this + } + } + + /** + * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]]. 
+ */ + private class UncompressedInBlockSort[@specialized(Int, Long) ID: ClassTag]( + implicit ord: Ordering[ID]) + extends SortDataFormat[KeyWrapper[ID], UncompressedInBlock[ID]] { + + override def newKey(): KeyWrapper[ID] = new KeyWrapper() + + override def getKey( + data: UncompressedInBlock[ID], + pos: Int, + reuse: KeyWrapper[ID]): KeyWrapper[ID] = { + if (reuse == null) { + new KeyWrapper().setKey(data.srcIds(pos)) + } else { + reuse.setKey(data.srcIds(pos)) + } + } + + override def getKey( + data: UncompressedInBlock[ID], + pos: Int): KeyWrapper[ID] = { + getKey(data, pos, null) + } + + private def swapElements[@specialized(Int, Float) T]( + data: Array[T], + pos0: Int, + pos1: Int): Unit = { + val tmp = data(pos0) + data(pos0) = data(pos1) + data(pos1) = tmp + } + + override def swap(data: UncompressedInBlock[ID], pos0: Int, pos1: Int): Unit = { + swapElements(data.srcIds, pos0, pos1) + swapElements(data.dstEncodedIndices, pos0, pos1) + swapElements(data.ratings, pos0, pos1) + } + + override def copyRange( + src: UncompressedInBlock[ID], + srcPos: Int, + dst: UncompressedInBlock[ID], + dstPos: Int, + length: Int): Unit = { + System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length) + System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length) + System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length) + } + + override def allocate(length: Int): UncompressedInBlock[ID] = { + new UncompressedInBlock( + new Array[ID](length), new Array[Int](length), new Array[Float](length)) + } + + override def copyElement( + src: UncompressedInBlock[ID], + srcPos: Int, + dst: UncompressedInBlock[ID], + dstPos: Int): Unit = { + dst.srcIds(dstPos) = src.srcIds(srcPos) + dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos) + dst.ratings(dstPos) = src.ratings(srcPos) + } + } + + /** + * Creates in-blocks and out-blocks from rating blocks. 
+ * @param prefix prefix for in/out-block names + * @param ratingBlocks rating blocks + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * @return (in-blocks, out-blocks) + */ + private def makeBlocks[ID: ClassTag]( + prefix: String, + ratingBlocks: RDD[((Int, Int), RatingBlock[ID])], + srcPart: Partitioner, + dstPart: Partitioner, + storageLevel: StorageLevel)( + implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = { + val inBlocks = ratingBlocks.map { + case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) => + // The implementation is a faster version of + // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap + val start = System.nanoTime() + val dstIdSet = new OpenHashSet[ID](1 << 20) + dstIds.foreach(dstIdSet.add) + val sortedDstIds = new Array[ID](dstIdSet.size) + var i = 0 + var pos = dstIdSet.nextPos(0) + while (pos != -1) { + sortedDstIds(i) = dstIdSet.getValue(pos) + pos = dstIdSet.nextPos(pos + 1) + i += 1 + } + assert(i == dstIdSet.size) + Sorting.quickSort(sortedDstIds) + val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.length) + i = 0 + while (i < sortedDstIds.length) { + dstIdToLocalIndex.update(sortedDstIds(i), i) + i += 1 + } + logDebug( + "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.") + val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply) + (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings)) + }.groupByKey(new ALSPartitioner(srcPart.numPartitions)) + .mapValues { iter => + val builder = + new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions)) + iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => + builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) + } + builder.build().compress() + }.setName(prefix + "InBlocks") + .persist(storageLevel) + val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) => + 
val encoder = new LocalIndexEncoder(dstPart.numPartitions) + val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int]) + var i = 0 + val seen = new Array[Boolean](dstPart.numPartitions) + while (i < srcIds.length) { + var j = dstPtrs(i) + ju.Arrays.fill(seen, false) + while (j < dstPtrs(i + 1)) { + val dstBlockId = encoder.blockId(dstEncodedIndices(j)) + if (!seen(dstBlockId)) { + activeIds(dstBlockId) += i // add the local index in this out-block + seen(dstBlockId) = true + } + j += 1 + } + i += 1 + } + activeIds.map { x => + x.result() + } + }.setName(prefix + "OutBlocks") + .persist(storageLevel) + (inBlocks, outBlocks) + } + + /** + * Compute dst factors by constructing and solving least square problems. + * + * @param srcFactorBlocks src factors + * @param srcOutBlocks src out-blocks + * @param dstInBlocks dst in-blocks + * @param rank rank + * @param regParam regularization constant + * @param srcEncoder encoder for src local indices + * @param implicitPrefs whether to use implicit preference + * @param alpha the alpha constant in the implicit preference formulation + * @param solver solver for least squares problems + * + * @return dst factors + */ + private def computeFactors[ID]( + srcFactorBlocks: RDD[(Int, FactorBlock)], + srcOutBlocks: RDD[(Int, OutBlock)], + dstInBlocks: RDD[(Int, InBlock[ID])], + rank: Int, + regParam: Double, + srcEncoder: LocalIndexEncoder, + implicitPrefs: Boolean = false, + alpha: Double = 1.0, + solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = { + val numSrcBlocks = srcFactorBlocks.partitions.length + val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None + val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap { + case (srcBlockId, (srcOutBlock, srcFactors)) => + srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) => + (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx)))) + } + } + val merged = srcOut.groupByKey(new 
ALSPartitioner(dstInBlocks.partitions.length)) + dstInBlocks.join(merged).mapValues { + case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => + val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) + srcFactors.foreach { case (srcBlockId, factors) => + sortedSrcFactors(srcBlockId) = factors + } + val dstFactors = new Array[Array[Float]](dstIds.length) + var j = 0 + val ls = new NormalEquation(rank) + while (j < dstIds.length) { + ls.reset() + if (implicitPrefs) { + ls.merge(YtY.get) + } + var i = srcPtrs(j) + while (i < srcPtrs(j + 1)) { + val encoded = srcEncodedIndices(i) + val blockId = srcEncoder.blockId(encoded) + val localIndex = srcEncoder.localIndex(encoded) + val srcFactor = sortedSrcFactors(blockId)(localIndex) + val rating = ratings(i) + if (implicitPrefs) { + ls.addImplicit(srcFactor, rating, alpha) + } else { + ls.add(srcFactor, rating) + } + i += 1 + } + dstFactors(j) = solver.solve(ls, regParam) + j += 1 + } + dstFactors + } + } + + /** + * Computes the Gramian matrix of user or item factors, which is only used in implicit preference. + * Caching of the input factors is handled in [[ALS#train]]. + */ + private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { + factorBlocks.values.aggregate(new NormalEquation(rank))( + seqOp = (ne, factors) => { + factors.foreach(ne.add(_, 0.0f)) + ne + }, + combOp = (ne1, ne2) => ne1.merge(ne2)) + } + + /** + * Encoder for storing (blockId, localIndex) into a single integer. + * + * We use the leading bits (including the sign bit) to store the block id and the rest to store + * the local index. This is based on the assumption that users/items are approximately evenly + * partitioned. With this assumption, we should be able to encode two billion distinct values. 
+ * + * @param numBlocks number of blocks + */ + private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable { + + require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.") + + private[this] final val numLocalIndexBits = + math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31) + private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1 + + /** Encodes a (blockId, localIndex) into a single integer. */ + def encode(blockId: Int, localIndex: Int): Int = { + require(blockId < numBlocks) + require((localIndex & ~localIndexMask) == 0) + (blockId << numLocalIndexBits) | localIndex + } + + /** Gets the block id from an encoded index. */ + @inline + def blockId(encoded: Int): Int = { + encoded >>> numLocalIndexBits + } + + /** Gets the local index from an encoded index. */ + @inline + def localIndex(encoded: Int): Int = { + encoded & localIndexMask + } + } + + /** + * Partitioner used by ALS. We requires that getPartition is a projection. That is, for any key k, + * we have getPartition(getPartition(k)) = getPartition(k). Since the the default HashPartitioner + * satisfies this requirement, we simply use a type alias here. + */ + private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala new file mode 100644 index 0000000000000..65f6627a0c351 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam} +import org.apache.spark.mllib.linalg.{BLAS, Vector} +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.sql.DataFrame +import org.apache.spark.storage.StorageLevel + + +/** + * Params for linear regression. + */ +private[regression] trait LinearRegressionParams extends RegressorParams + with HasRegParam with HasMaxIter + + +/** + * :: AlphaComponent :: + * + * Linear regression. + */ +@AlphaComponent +class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] + with LinearRegressionParams { + + setRegParam(0.1) + setMaxIter(100) + + /** @group setParam */ + def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
+ val oldDataset = extractLabeledPoints(dataset, paramMap) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + } + + // Train model + val lr = new LinearRegressionWithSGD() + lr.optimizer + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) + val model = lr.run(oldDataset) + val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + + if (handlePersistence) { + oldDataset.unpersist() + } + lrm + } +} + +/** + * :: AlphaComponent :: + * + * Model produced by [[LinearRegression]]. + */ +@AlphaComponent +class LinearRegressionModel private[ml] ( + override val parent: LinearRegression, + override val fittingParamMap: ParamMap, + val weights: Vector, + val intercept: Double) + extends RegressionModel[Vector, LinearRegressionModel] + with LinearRegressionParams { + + override protected def predict(features: Vector): Double = { + BLAS.dot(features, weights) + intercept + } + + override protected def copy(): LinearRegressionModel = { + val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) + Params.inheritValues(this.paramMap, this, m) + m + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala new file mode 100644 index 0000000000000..d679085eeafe1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} + +/** + * :: DeveloperApi :: + * Params for regression. + * Currently empty, but may add functionality later. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait RegressorParams extends PredictorParams + +/** + * :: AlphaComponent :: + * + * Single-label regression + * + * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] + * @tparam Learner Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Regressor[ + FeaturesType, + Learner <: Regressor[FeaturesType, Learner, M], + M <: RegressionModel[FeaturesType, M]] + extends Predictor[FeaturesType, Learner, M] + with RegressorParams { + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * :: AlphaComponent :: + * + * Model produced by a [[Regressor]]. + * + * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] + * @tparam M Concrete Model type. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. 
+ */ +@AlphaComponent +private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with RegressorParams { + + /** + * :: DeveloperApi :: + * + * Predict real-valued label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + @DeveloperApi + protected def predict(features: FeaturesType): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 08fe99176424a..2eb1dac56f1e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,29 +24,49 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. 
*/ private[ml] trait CrossValidatorParams extends Params { - /** param for the estimator to be cross-validated */ + /** + * param for the estimator to be cross-validated + * @group param + */ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") + + /** @group getParam */ def getEstimator: Estimator[_] = get(estimator) - /** param for estimator param maps */ + /** + * param for estimator param maps + * @group param + */ val estimatorParamMaps: Param[Array[ParamMap]] = new Param(this, "estimatorParamMaps", "param maps for the estimator") + + /** @group getParam */ def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) - /** param for the evaluator for selection */ + /** + * param for the evaluator for selection + * @group param + */ val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + + /** @group getParam */ def getEvaluator: Evaluator = get(evaluator) - /** param for number of folds for cross validation */ + /** + * param for number of folds for cross validation + * @group param + */ val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + + /** @group getParam */ def getNumFolds: Int = get(numFolds) } @@ -59,12 +79,19 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP private val f2jBLAS = new F2jBLAS + /** @group setParam */ def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + + /** @group setParam */ def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + + /** @group setParam */ def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + + /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { val map = this.paramMap ++ 
paramMap val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) @@ -74,13 +101,14 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val epm = map(estimatorParamMaps) val numModels = epm.size val metrics = new Array[Double](epm.size) - val splits = MLUtils.kFold(dataset, map(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => - val trainingDataset = sqlCtx.applySchema(training, schema).cache() - val validationDataset = sqlCtx.applySchema(validation, schema).cache() + val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() + val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + trainingDataset.unpersist() var i = 0 while (i < numModels) { val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) @@ -88,6 +116,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP metrics(i) += metric i += 1 } + validationDataset.unpersist() } f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") @@ -100,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP cvModel } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = this.paramMap ++ paramMap map(estimator).transformSchema(schema, paramMap) } @@ -117,11 +146,11 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + 
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { bestModel.transform(dataset, paramMap) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { bestModel.transformSchema(schema, paramMap) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 555da8c7e7ab3..cbd87ea8aeb37 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -22,6 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder} import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.language.existentials import scala.reflect.ClassTag @@ -40,22 +41,22 @@ import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree} -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree} +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy} import org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.loss.Losses +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel} import 
org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils /** - * :: DeveloperApi :: - * The Java stubs necessary for the Python mllib bindings. + * The Java stubs necessary for the Python mllib bindings. It is called by Py4J on the Python side. */ -@DeveloperApi -class PythonMLLibAPI extends Serializable { +private[python] class PythonMLLibAPI extends Serializable { /** @@ -259,19 +260,23 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python mllib KMeans.train() + * Java stub for Python mllib KMeans.run() */ def trainKMeansModel( data: JavaRDD[Vector], k: Int, maxIterations: Int, runs: Int, - initializationMode: String): KMeansModel = { + initializationMode: String, + seed: java.lang.Long): KMeansModel = { val kMeansAlg = new KMeans() .setK(k) .setMaxIterations(maxIterations) .setRuns(runs) .setInitializationMode(initializationMode) + + if (seed != null) kMeansAlg.setSeed(seed) + try { kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) } finally { @@ -279,6 +284,58 @@ class PythonMLLibAPI extends Serializable { } } + /** + * Java stub for Python mllib GaussianMixture.run() + * Returns a list containing weights, mean and covariance of each mixture component. 
+ */ + def trainGaussianMixture( + data: JavaRDD[Vector], + k: Int, + convergenceTol: Double, + maxIterations: Int, + seed: java.lang.Long): JList[Object] = { + val gmmAlg = new GaussianMixture() + .setK(k) + .setConvergenceTol(convergenceTol) + .setMaxIterations(maxIterations) + + if (seed != null) gmmAlg.setSeed(seed) + + try { + val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) + var wt = ArrayBuffer.empty[Double] + var mu = ArrayBuffer.empty[Vector] + var sigma = ArrayBuffer.empty[Matrix] + for (i <- 0 until model.k) { + wt += model.weights(i) + mu += model.gaussians(i).mu + sigma += model.gaussians(i).sigma + } + List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + } finally { + data.rdd.unpersist(blocking = false) + } + } + + /** + * Java stub for Python mllib GaussianMixtureModel.predictSoft() + */ + def predictSoftGMM( + data: JavaRDD[Vector], + wt: Object, + mu: Array[Object], + si: Array[Object]): RDD[Array[Double]] = { + + val weight = wt.asInstanceOf[Array[Double]] + val mean = mu.map(_.asInstanceOf[DenseVector]) + val sigma = si.map(_.asInstanceOf[DenseMatrix]) + val gaussians = Array.tabulate(weight.length){ + i => new MultivariateGaussian(mean(i), sigma(i)) + } + val model = new GaussianMixtureModel(weight, gaussians) + model.predictSoft(data) + } + /** * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python */ @@ -528,6 +585,35 @@ class PythonMLLibAPI extends Serializable { } } + /** + * Java stub for Python mllib GradientBoostedTrees.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. 
+ */ + def trainGradientBoostedTreesModel( + data: JavaRDD[LabeledPoint], + algoStr: String, + categoricalFeaturesInfo: JMap[Int, Int], + lossStr: String, + numIterations: Int, + learningRate: Double, + maxDepth: Int): GradientBoostedTreesModel = { + val boostingStrategy = BoostingStrategy.defaultParams(algoStr) + boostingStrategy.setLoss(Losses.fromString(lossStr)) + boostingStrategy.setNumIterations(numIterations) + boostingStrategy.setLearningRate(learningRate) + boostingStrategy.treeStrategy.setMaxDepth(maxDepth) + boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap + + val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + try { + GradientBoostedTrees.train(cached, boostingStrategy) + } finally { + cached.unpersist(blocking = false) + } + } + /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index b7a1d90d24d72..35a0db76f3a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.classification +import org.json4s.{DefaultFormats, JValue} + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector @@ -53,3 +55,15 @@ trait ClassificationModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object ClassificationModel { + + /** + * Helper method for loading GLM classification model metadata. 
+ * @return (numFeatures, numClasses) + */ + def getNumFeaturesClasses(metadata: JValue): (Int, Int) = { + implicit val formats = DefaultFormats + ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 94d757bc317ab..b787667b018e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,31 +17,63 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.classification.impl.GLMClassificationModel +import org.apache.spark.mllib.linalg.BLAS.dot +import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.DataValidators +import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD + /** - * Classification model trained using Logistic Regression. + * Classification model trained using Multinomial/Binary Logistic Regression. * * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. + * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression. + * In Multinomial Logistic Regression, the intercepts will not be a single value, + * so the intercepts will be part of the weights.) + * @param numFeatures the dimension of the features. + * @param numClasses the number of possible outcomes for k classes classification problem in + * Multinomial Logistic Regression. 
By default, it is binary logistic regression + * so numClasses will be set to 2. */ class LogisticRegressionModel ( override val weights: Vector, - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + override val intercept: Double, + val numFeatures: Int, + val numClasses: Int) + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable { + + if (numClasses == 2) { + require(weights.size == numFeatures, + s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" + + s" numFeatures = $numFeatures, but weights.size = ${weights.size}") + } else { + val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures + val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1) + require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept, + s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" + + s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" + + s" or $weightsSizeWithIntercept (with intercept)," + + s" but was given weights of length ${weights.size}") + } + + /** + * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. + */ + def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) private var threshold: Option[Double] = Some(0.5) /** * :: Experimental :: - * Sets the threshold that separates positive predictions from negative predictions. An example - * with prediction score greater than or equal to this threshold is identified as an positive, - * and negative otherwise. The default value is 0.5. + * Sets the threshold that separates positive predictions from negative predictions + * in Binary Logistic Regression. 
An example with prediction score strictly greater than + * this threshold is identified as a positive, and negative otherwise. The default value is 0.5. */ @Experimental def setThreshold(threshold: Double): this.type = { @@ -49,6 +81,13 @@ class LogisticRegressionModel ( this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + */ + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. @@ -59,25 +98,108 @@ class LogisticRegressionModel ( this } - override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double) = { - val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - val score = 1.0 / (1.0 + math.exp(-margin)) - threshold match { - case Some(t) => if (score > t) 1.0 else 0.0 - case None => score + require(dataMatrix.size == numFeatures) + + // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. + if (numClasses == 2) { + require(numFeatures == weightMatrix.size) + val margin = dot(weightMatrix, dataMatrix) + intercept + val score = 1.0 / (1.0 + math.exp(-margin)) + threshold match { + case Some(t) => if (score > t) 1.0 else 0.0 + case None => score + } + } else { + val dataWithBiasSize = weightMatrix.size / (numClasses - 1) + + val weightsArray = weightMatrix match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weightMatrix.getClass}.") + } + + val margins = (0 until numClasses - 1).map { i => + var margin = 0.0 + dataMatrix.foreachActive { (index, value) => + if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) + } + // Intercept is required to be added into margin.
+ if (dataMatrix.size + 1 == dataWithBiasSize) { + margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) + } + margin + } + + /** + * Find the one with maximum margins. If the maxMargin is negative, then the prediction + * result will be the first class. + * + * PS, if you want to compute the probabilities for each outcome instead of the outcome + * with maximum probability, remember to subtract the maxMargin from margins if maxMargin + * is positive to prevent overflow. + */ + var bestClass = 0 + var maxMargin = 0.0 + var i = 0 + while(i < margins.size) { + if (margins(i) > maxMargin) { + maxMargin = margins(i) + bestClass = i + 1 + } + i += 1 + } + bestClass.toDouble + } + } + + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures, numClasses, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" +} + +object LogisticRegressionModel extends Loader[LogisticRegressionModel] { + + override def load(sc: SparkContext, path: String): LogisticRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + // numFeatures, numClasses, weights are checked in model initialization + val model = + new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses) + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"LogisticRegressionModel.load did not recognize model with 
(className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") } } } /** - * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By - * default L2 regularization is used, which can be changed via - * [[LogisticRegressionWithSGD.optimizer]]. - * NOTE: Labels used in Logistic Regression should be {0, 1}. + * Train a classification model for Binary Logistic Regression + * using Stochastic Gradient Descent. By default L2 regularization is used, + * which can be changed via [[LogisticRegressionWithSGD.optimizer]]. + * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for a k-class (multiclass) classification problem. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ -class LogisticRegressionWithSGD private ( +class LogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, private var regParam: Double, @@ -99,7 +221,7 @@ class LogisticRegressionWithSGD private ( */ def this() = this(1.0, 100, 0.01, 1.0) - override protected def createModel(weights: Vector, intercept: Double) = { + override protected[mllib] def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) } } @@ -194,9 +316,10 @@ object LogisticRegressionWithSGD { } /** - * Train a classification model for Logistic Regression using Limited-memory BFGS. - * Standard feature scaling and L2 regularization are used by default. - * NOTE: Labels used in Logistic Regression should be {0, 1} + * Train a classification model for Multinomial/Binary Logistic Regression using + * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. + * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for a k-class (multiclass) classification problem.
*/ class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { @@ -205,9 +328,37 @@ class LogisticRegressionWithLBFGS override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) - override protected val validators = List(DataValidators.binaryLabelValidator) + override protected val validators = List(multiLabelValidator) + + private def multiLabelValidator: RDD[LabeledPoint] => Boolean = { data => + if (numOfLinearPredictor > 1) { + DataValidators.multiLabelValidator(numOfLinearPredictor + 1)(data) + } else { + DataValidators.binaryLabelValidator(data) + } + } + + /** + * :: Experimental :: + * Set the number of possible outcomes for k classes classification problem in + * Multinomial Logistic Regression. + * By default, it is binary logistic regression so k will be set to 2. + */ + @Experimental + def setNumClasses(numClasses: Int): this.type = { + require(numClasses > 1) + numOfLinearPredictor = numClasses - 1 + if (numClasses > 2) { + optimizer.setGradient(new LogisticGradient(numClasses)) + } + this + } override protected def createModel(weights: Vector, intercept: Double) = { - new LogisticRegressionModel(weights, intercept) + if (numOfLinearPredictor == 1) { + new LogisticRegressionModel(weights, intercept) + } else { + new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 4ec4bdf9f18a6..4269d80f028fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -20,12 +20,17 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} import breeze.numerics.{exp => brzExp, log 
=> brzLog} -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.SparkContext._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ +import org.json4s.{DefaultFormats, JValue} + +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} /** @@ -34,6 +39,10 @@ import org.apache.spark.rdd.RDD object NaiveBayesModels extends Enumeration { type NaiveBayesModels = Value val Multinomial, Bernoulli = Value + + implicit def toString(model: NaiveBayesModels): String = { + model.toString + } } /** @@ -43,7 +52,7 @@ object NaiveBayesModels extends Enumeration { * @param pi log of class priors, whose dimension is C, number of labels * @param theta log of class conditional probabilities, whose dimension is C-by-D, * where D is number of features - * @param model The type of NB model to fit from the enumeration NaiveBayesModels, can be + * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be * Multinomial or Bernoulli */ @@ -51,12 +60,12 @@ class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], val theta: Array[Array[Double]], - val model: NaiveBayesModels) extends ClassificationModel with Serializable { + val modelType: NaiveBayesModels) extends ClassificationModel with Serializable with Saveable { private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t - private val brzNegTheta: Option[BDM[Double]] = model match { + private val brzNegTheta: Option[BDM[Double]] = modelType match { case NaiveBayesModels.Multinomial => None case NaiveBayesModels.Bernoulli => val 
negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) @@ -72,7 +81,7 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { - model match { + modelType match { case NaiveBayesModels.Multinomial => labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) case NaiveBayesModels.Bernoulli => @@ -81,6 +90,91 @@ class NaiveBayesModel private[mllib] ( brzSum(brzNegTheta.get, Axis._1))) } } + + override def save(sc: SparkContext, path: String): Unit = { + val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) + } + + override protected def formatVersion: String = "1.0" +} + +object NaiveBayesModel extends Loader[NaiveBayesModel] { + + import org.apache.spark.mllib.util.Loader._ + + private object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]], modelType: String) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length) ~ + ("modelType" -> data.modelType))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. 
+ val dataRDD = sqlContext.parquetFile(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + val modelType: NaiveBayesModels = NaiveBayesModels.withName(data.getAs[String](3)) + new NaiveBayesModel(labels, pi, theta, modelType) + } + } + + override def load(sc: SparkContext, path: String): NaiveBayesModel = { + def getModelType(metadata: JValue): NaiveBayesModels = { + implicit val formats = DefaultFormats + NaiveBayesModels.withName((metadata \ "modelType").extract[String]) + } + val (loadedClassName, version, metadata) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV1_0.load(sc, path) + assert(model.pi.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.size} elements") + assert(model.theta.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.size} elements") + assert(model.theta.forall(_.size == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.size).mkString(",")}") + assert(model.modelType == getModelType(metadata)) + model + case _ => throw new Exception( + s"NaiveBayesModel.load did not recognize model with (className, 
format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** @@ -92,7 +186,7 @@ class NaiveBayesModel private[mllib] ( * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ class NaiveBayes private (private var lambda: Double, - var model: NaiveBayesModels) extends Serializable with Logging { + var modelType: NaiveBayesModels) extends Serializable with Logging { def this(lambda: Double) = this(lambda, NaiveBayesModels.Multinomial) @@ -106,7 +200,7 @@ class NaiveBayes private (private var lambda: Double, /** Set the model type. Default: Multinomial. */ def setModelType(model: NaiveBayesModels): NaiveBayes = { - this.model = model + this.modelType = model this } @@ -161,7 +255,7 @@ class NaiveBayes private (private var lambda: Double, aggregated.foreach { case (label, (n, sumTermFreqs)) => labels(i) = label pi(i) = math.log(n + lambda) - piLogDenom - val thetaLogDenom = model match { + val thetaLogDenom = modelType match { case NaiveBayesModels.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) case NaiveBayesModels.Bernoulli => math.log(n + 2.0 * lambda) } @@ -173,7 +267,7 @@ class NaiveBayes private (private var lambda: Double, i += 1 } - new NaiveBayesModel(labels, pi, theta, model) + new NaiveBayesModel(labels, pi, theta, modelType) } } @@ -226,10 +320,10 @@ object NaiveBayes { * vector or a count vector. 
* @param lambda The smoothing parameter * - * @param model The type of NB model to fit from the enumeration NaiveBayesModels, can be + * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be * Multinomial or Bernoulli */ - def train(input: RDD[LabeledPoint], lambda: Double, model: String): NaiveBayesModel = { - new NaiveBayes(lambda, NaiveBayesModels.withName(model)).run(input) + def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { + new NaiveBayes(lambda, NaiveBayesModels.withName(modelType)).run(input) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index dd514ff8a37f2..cfc7f868a02f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -17,11 +17,13 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.DataValidators +import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD /** @@ -33,7 +35,8 @@ import org.apache.spark.rdd.RDD class SVMModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable { private var threshold: Option[Double] = Some(0.0) @@ -49,6 +52,13 @@ class SVMModel ( this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 
predictions. + */ + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. @@ -69,6 +79,41 @@ class SVMModel ( case None => margin } } + + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" +} + +object SVMModel extends Loader[SVMModel] { + + override def load(sc: SparkContext, path: String): SVMModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val model = new SVMModel(data.weights, data.intercept) + assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" + + s" was given non-matching weights vector of size ${model.weights.size}") + assert(numClasses == 2, + s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes") + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"SVMModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala new file mode 100644 index 0000000000000..b89f38cf5aba4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.StreamingLinearAlgorithm + +/** + * :: Experimental :: + * Train or predict a logistic regression model on streaming data. Training uses + * Stochastic Gradient Descent to update the model based on each new batch of + * incoming data from a DStream (see `LogisticRegressionWithSGD` for model equation) + * + * Each batch of data is assumed to be an RDD of LabeledPoints. + * The number of data points per batch can vary, but the number + * of features must be constant. An initial weight + * vector must be provided. 
+ * + * Use a builder pattern to construct a streaming logistic regression + * analysis in an application, like: + * + * {{{ + * val model = new StreamingLogisticRegressionWithSGD() + * .setStepSize(0.5) + * .setNumIterations(10) + * .setInitialWeights(Vectors.dense(...)) + * .trainOn(DStream) + * }}} + */ +@Experimental +class StreamingLogisticRegressionWithSGD private[mllib] ( + private var stepSize: Double, + private var numIterations: Int, + private var miniBatchFraction: Double, + private var regParam: Double) + extends StreamingLinearAlgorithm[LogisticRegressionModel, LogisticRegressionWithSGD] + with Serializable { + + /** + * Construct a StreamingLogisticRegression object with default parameters: + * {stepSize: 0.1, numIterations: 50, miniBatchFraction: 1.0, regParam: 0.0}. + * Initial weights must be set before using trainOn or predictOn + * (see `StreamingLinearAlgorithm`) + */ + def this() = this(0.1, 50, 1.0, 0.0) + + protected val algorithm = new LogisticRegressionWithSGD( + stepSize, numIterations, regParam, miniBatchFraction) + + /** Set the step size for gradient descent. Default: 0.1. */ + def setStepSize(stepSize: Double): this.type = { + this.algorithm.optimizer.setStepSize(stepSize) + this + } + + /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + def setNumIterations(numIterations: Int): this.type = { + this.algorithm.optimizer.setNumIterations(numIterations) + this + } + + /** Set the fraction of each batch to use for updates. Default: 1.0. */ + def setMiniBatchFraction(miniBatchFraction: Double): this.type = { + this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) + this + } + + /** Set the regularization parameter. Default: 0.0. */ + def setRegParam(regParam: Double): this.type = { + this.algorithm.optimizer.setRegParam(regParam) + this + } + + /** Set the initial weights. Default: [0.0, 0.0]. 
*/ + def setInitialWeights(initialWeights: Vector): this.type = { + this.model = Some(algorithm.createModel(initialWeights, 0.0)) + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala new file mode 100644 index 0000000000000..8956189ff1158 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification.impl + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{Row, SQLContext} + +/** + * Helper class for import/export of GLM classification models. 
+ */ +private[classification] object GLMClassificationModel { + + object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + /** Model data for import/export */ + case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + + /** + * Helper method for saving GLM classification model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + * @param numClasses Number of classes label can take, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + numFeatures: Int, + numClasses: Int, + weights: Vector, + intercept: Double, + threshold: Option[Double]): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept, threshold) + sc.parallelize(Seq(data), 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM classification model data. + * + * NOTE: Callers of this method should check numClasses, numFeatures on their own. 
+ * + * @param modelClass String name for model class (used for error messages) + */ + def loadData(sc: SparkContext, path: String, modelClass: String): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(datapath) + val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 3, s"Unable to load $modelClass data from: $datapath") + val (weights, intercept) = data match { + case Row(weights: Vector, intercept: Double, _) => + (weights, intercept) + } + val threshold = if (data.isNullAt(2)) { + None + } else { + Some(data.getDouble(2)) + } + Data(weights, intercept, threshold) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala similarity index 83% rename from mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala rename to mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index d8e134619411b..568b65305649f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -19,15 +19,18 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.IndexedSeq -import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} +import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} -import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS} +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import 
org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** + * :: Experimental :: + * * This class performs expectation maximization for multivariate Gaussian * Mixture Models (GMMs). A GMM represents a composite distribution of * independent Gaussian distributions with associated "mixing" weights @@ -38,19 +41,27 @@ import org.apache.spark.util.Utils * less than convergenceTol, or until it has reached the max number of iterations. * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. - * + * + * Note: For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. + * * @param k The number of independent Gaussians in the mixture model * @param convergenceTol The maximum change in log-likelihood at which convergence * is considered to have occurred. * @param maxIterations The maximum number of iterations to perform */ -class GaussianMixtureEM private ( +@Experimental +class GaussianMixture private ( private var k: Int, private var convergenceTol: Double, private var maxIterations: Int, private var seed: Long) extends Serializable { - /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ + /** + * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, + * maxIterations: 100, seed: random}. 
+ */ def this() = this(2, 0.01, 100, Utils.random.nextLong()) // number of samples per cluster to use when initializing Gaussians @@ -123,7 +134,7 @@ class GaussianMixtureEM private ( val sc = data.sparkContext // we will operate on the data as breeze data - val breezeData = data.map(u => u.toBreeze.toDenseVector).cache() + val breezeData = data.map(_.toBreeze).cache() // Get length of the input vectors val d = breezeData.first().length @@ -134,16 +145,14 @@ class GaussianMixtureEM private ( // diagonal covariance matrices using component variances // derived from the samples val (weights, gaussians) = initialModel match { - case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => - new MultivariateGaussian(mu, sigma) - }) + case Some(gmm) => (gmm.weights, gmm.gaussians) case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) - }) + }) } } @@ -164,7 +173,7 @@ class GaussianMixtureEM private ( var i = 0 while (i < k) { val mu = sums.means(i) / sums.weights(i) - BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector], + BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu), Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) weights(i) = sums.weights(i) / sumWeights gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) @@ -176,15 +185,12 @@ class GaussianMixtureEM private ( iter += 1 } - // Need to convert the breeze matrices to MLlib matrices - val means = Array.tabulate(k) { i => gaussians(i).mu } - val sigmas = Array.tabulate(k) { i => gaussians(i).sigma } - new GaussianMixtureModel(weights, means, sigmas) + new GaussianMixtureModel(weights, gaussians) } /** Average of dense breeze vectors */ - private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = { - 
val v = BreezeVector.zeros[Double](x(0).length) + private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { + val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) v / x.length.toDouble } @@ -193,10 +199,10 @@ class GaussianMixtureEM private ( * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) */ - private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = { + private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = { val mu = vectorMean(x) - val ss = BreezeVector.zeros[Double](x(0).length) - x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u) + val ss = BDV.zeros[Double](x(0).length) + x.foreach(xi => ss += (xi - mu) :^ 2.0) diag(ss / x.length.toDouble) } } @@ -205,7 +211,7 @@ class GaussianMixtureEM private ( private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { new ExpectationSum(0.0, Array.fill(k)(0.0), - Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) } // compute cluster contributions for each input point @@ -213,19 +219,18 @@ private object ExpectationSum { def add( weights: Array[Double], dists: Array[MultivariateGaussian]) - (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = { + (sums: ExpectationSum, x: BV[Double]): ExpectationSum = { val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x) } val pSum = p.sum sums.logLikelihood += math.log(pSum) - val xxt = x * new Transpose(x) var i = 0 while (i < sums.k) { p(i) /= pSum sums.weights(i) += p(i) sums.means(i) += x * p(i) - BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector], + BLAS.syr(p(i), Vectors.fromBreeze(x), Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) i = i + 1 } @@ -237,7 +242,7 @@ private object ExpectationSum { private class ExpectationSum( var logLikelihood: 
Double, val weights: Array[Double], - val means: Array[BreezeVector[Double]], + val means: Array[BDV[Double]], val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { val k = weights.length diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 416cad080c408..af6f83c74bb40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -19,12 +19,15 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD /** + * :: Experimental :: + * * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are * the respective mean and covariance for each Gaussian distribution i=1..k. @@ -35,13 +38,15 @@ import org.apache.spark.mllib.util.MLUtils * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the * covariance matrix for Gaussian i */ +@Experimental class GaussianMixtureModel( - val weight: Array[Double], - val mu: Array[Vector], - val sigma: Array[Matrix]) extends Serializable { + val weights: Array[Double], + val gaussians: Array[MultivariateGaussian]) extends Serializable { + + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") /** Number of gaussians in mixture */ - def k: Int = weight.length + def k: Int = weights.length /** Maps given points to their cluster indices. 
*/ def predict(points: RDD[Vector]): RDD[Int] = { @@ -55,14 +60,10 @@ class GaussianMixtureModel( */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext - val dists = sc.broadcast { - (0 until k).map { i => - new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) - }.toArray - } - val weights = sc.broadcast(weight) + val bcDists = sc.broadcast(gaussians) + val bcWeights = sc.broadcast(weights) points.map { x => - computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k) + computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 54c301d3e9e14..11633e8242313 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer -import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** @@ -43,13 +43,14 @@ class KMeans private ( private var runs: Int, private var initializationMode: String, private var initializationSteps: Int, - private var epsilon: Double) extends Serializable with Logging { + private var epsilon: Double, + private var seed: Long) extends Serializable with Logging { /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, - * 
initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}. + * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}. */ - def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) + def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) /** Set the number of clusters to create (k). Default: 2. */ def setK(k: Int): this.type = { @@ -112,6 +113,12 @@ class KMeans private ( this } + /** Set the random seed for cluster initialization. */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. @@ -255,7 +262,7 @@ class KMeans private ( private def initRandom(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq + val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm) }.toArray) @@ -272,45 +279,81 @@ class KMeans private ( */ private def initKMeansParallel(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { - // Initialize each run's center to a random point - val seed = new XORShiftRandom().nextInt() + // Initialize empty centers and point costs. + val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) + var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + + // Initialize each run's first center to a random point. 
+ val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq - val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + + /** Merges new centers to centers. */ + def mergeNewCenters(): Unit = { + var r = 0 + while (r < runs) { + centers(r) ++= newCenters(r) + newCenters(r).clear() + r += 1 + } + } // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's current centers + // to their squared distance from that run's centers. Note that only distances between points + // and new centers are computed in each iteration. var step = 0 while (step < initializationSteps) { - val bcCenters = data.context.broadcast(centers) - val sumCosts = data.flatMap { point => - (0 until runs).map { r => - (r, KMeans.pointCost(bcCenters.value(r), point)) - } - }.reduceByKey(_ + _).collectAsMap() - val chosen = data.mapPartitionsWithIndex { (index, points) => + val bcNewCenters = data.context.broadcast(newCenters) + val preCosts = costs + costs = data.zip(preCosts).map { case (point, cost) => + Vectors.dense( + Array.tabulate(runs) { r => + math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) + }) + }.cache() + val sumCosts = costs + .aggregate(Vectors.zeros(runs))( + seqOp = (s, v) => { + // s += v + axpy(1.0, v, s) + s + }, + combOp = (s0, s1) => { + // s0 += s1 + axpy(1.0, s1, s0) + s0 + } + ) + preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) - points.flatMap { p => - (0 until runs).filter { r => - rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r) - }.map((_, p)) + pointsWithCosts.flatMap { case (p, c) => + val rs = (0 until runs).filter { r => + rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) 
+ } + if (rs.length > 0) Some(p, rs) else None } }.collect() - chosen.foreach { case (r, p) => - centers(r) += p.toDense + mergeNewCenters() + chosen.foreach { case (p, rs) => + rs.foreach(newCenters(_) += p.toDense) } step += 1 } + mergeNewCenters() + costs.unpersist(blocking = false) + // Finally, we might have a set of more than k candidate centers for each run; weigh each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick just k of them val bcCenters = data.context.broadcast(centers) val weightMap = data.flatMap { p => - (0 until runs).map { r => + Iterator.tabulate(runs) { r => ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() - val finalCenters = (0 until runs).map { r => + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) @@ -333,7 +376,32 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Array[Double]]` + * @param data training points stored as `RDD[Vector]` + * @param k number of clusters + * @param maxIterations max number of iterations + * @param runs number of parallel runs, defaults to 1. The best model is returned. + * @param initializationMode initialization model, either "random" or "k-means||" (default). + * @param seed random seed value for cluster initialization + */ + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int, + runs: Int, + initializationMode: String, + seed: Long): KMeansModel = { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setRuns(runs) + .setInitializationMode(initializationMode) + .setSeed(seed) + .run(data) + } + + /** + * Trains a k-means model using the given set of parameters. 
+ * + * @param data training points stored as `RDD[Vector]` * @param k number of clusters * @param maxIterations max number of iterations * @param runs number of parallel runs, defaults to 1. The best model is returned. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala new file mode 100644 index 0000000000000..5e17c8da61134 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -0,0 +1,475 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. 
+ * + * Terminology: + * - "word" = "term": an element of the vocabulary + * - "token": instance of a term appearing in a document + * - "topic": multinomial distribution over words representing some concept + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation + * (Wikipedia)]] + */ +@Experimental +class LDA private ( + private var k: Int, + private var maxIterations: Int, + private var docConcentration: Double, + private var topicConcentration: Double, + private var seed: Long, + private var checkpointInterval: Int) extends Logging { + + def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, + seed = Utils.random.nextLong(), checkpointInterval = 10) + + /** + * Number of topics to infer. I.e., the number of soft cluster centers. + */ + def getK: Int = k + + /** + * Number of topics to infer. I.e., the number of soft cluster centers. + * (default = 10) + */ + def setK(k: Int): this.type = { + require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") + this.k = k + this + } + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a symmetric Dirichlet distribution. 
+ */ + def getDocConcentration: Double = { + if (this.docConcentration == -1) { + (50.0 / k) + 1.0 + } else { + this.docConcentration + } + } + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * This value should be > 1.0, where larger values mean more smoothing (more regularization). + * If set to -1, then docConcentration is set automatically. + * (default = -1 = automatic) + * + * Automatic setting of parameter: + * - For EM: default = (50 / k) + 1. + * - The 50/k is common in LDA libraries. + * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * + * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + * but values in (0,1) are not yet supported. + */ + def setDocConcentration(docConcentration: Double): this.type = { + require(docConcentration > 1.0 || docConcentration == -1.0, + s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration") + this.docConcentration = docConcentration + this + } + + /** Alias for [[getDocConcentration]] */ + def getAlpha: Double = getDocConcentration + + /** Alias for [[setDocConcentration()]] */ + def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. 
+ */ + def getTopicConcentration: Double = { + if (this.topicConcentration == -1) { + 1.1 + } else { + this.topicConcentration + } + } + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + * + * This value should be > 1.0. + * If set to -1, then topicConcentration is set automatically. + * (default = -1 = automatic) + * + * Automatic setting of parameter: + * - For EM: default = 0.1 + 1. + * - The 0.1 gives a small amount of smoothing. + * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * + * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + * but values in (0,1) are not yet supported. + */ + def setTopicConcentration(topicConcentration: Double): this.type = { + require(topicConcentration > 1.0 || topicConcentration == -1.0, + s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration") + this.topicConcentration = topicConcentration + this + } + + /** Alias for [[getTopicConcentration]] */ + def getBeta: Double = getTopicConcentration + + /** Alias for [[setTopicConcentration()]] */ + def setBeta(beta: Double): this.type = setTopicConcentration(beta) + + /** + * Maximum number of iterations for learning. + */ + def getMaxIterations: Int = maxIterations + + /** + * Maximum number of iterations for learning. + * (default = 20) + */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** Random seed */ + def getSeed: Long = seed + + /** Random seed */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** + * Period (in iterations) between checkpoints. 
+ */ + def getCheckpointInterval: Int = checkpointInterval + + /** + * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery + * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be + * important when LDA is run for many iterations. If the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. + * + * @see [[org.apache.spark.SparkContext#setCheckpointDir]] + */ + def setCheckpointInterval(checkpointInterval: Int): this.type = { + this.checkpointInterval = checkpointInterval + this + } + + /** + * Learn an LDA model using the given dataset. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * Document IDs must be unique and >= 0. + * @return Inferred LDA model + */ + def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { + val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, + checkpointInterval) + var iter = 0 + val iterationTimes = Array.fill[Double](maxIterations)(0) + while (iter < maxIterations) { + val start = System.nanoTime() + state.next() + val elapsedSeconds = (System.nanoTime() - start) / 1e9 + iterationTimes(iter) = elapsedSeconds + iter += 1 + } + state.graphCheckpointer.deleteAllCheckpoints() + new DistributedLDAModel(state, iterationTimes) + } + + /** Java-friendly version of [[run()]] */ + def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } +} + + +private[clustering] object LDA { + + /* + DEVELOPERS NOTE: + + This implementation uses GraphX, where the graph is bipartite with 2 types of vertices: + - Document vertices + - indexed with unique indices >= 0 + - Store vectors of length k (# topics). 
+ - Term vertices + - indexed {-1, -2, ..., -vocabSize} + - Store vectors of length k (# topics). + - Edges correspond to terms appearing in documents. + - Edges are directed Document -> Term. + - Edges are partitioned by documents. + + Info on EM implementation. + - We follow Section 2.2 from Asuncion et al., 2009. We use some of their notation. + - In this implementation, there is one edge for every unique term appearing in a document, + i.e., for every unique (document, term) pair. + - Notation: + - N_{wkj} = count of tokens of term w currently assigned to topic k in document j + - N_{*} where * is missing a subscript w/k/j is the count summed over missing subscript(s) + - gamma_{wjk} = P(z_i = k | x_i = w, d_i = j), + the probability of term x_i in document d_i having topic z_i. + - Data graph + - Document vertices store N_{kj} + - Term vertices store N_{wk} + - Edges store N_{wj}. + - Global data N_k + - Algorithm + - Initial state: + - Document and term vertices store random counts N_{wk}, N_{kj}. + - E-step: For each (document,term) pair i, compute P(z_i | x_i, d_i). + - Aggregate N_k from term vertices. + - Compute gamma_{wjk} for each possible topic k, from each triplet. + using inputs N_{wk}, N_{kj}, N_k. + - M-step: Compute sufficient statistics for hidden parameters phi and theta + (counts N_{wk}, N_{kj}, N_k). + - Document update: + - N_{kj} <- sum_w N_{wj} gamma_{wjk} + - N_j <- sum_k N_{kj} (only needed to output predictions) + - Term update: + - N_{wk} <- sum_j N_{wj} gamma_{wjk} + - N_k <- sum_w N_{wk} + + TODO: Add simplex constraints to allow alpha in (0,1). + See: Vorontsov and Potapenko. "Tutorial on Probabilistic Topic Modeling : Additive + Regularization for Stochastic Matrix Factorization." 2014. + */ + + /** + * Vector over topics (length k) of token counts. + * The meaning of these counts can vary, and it may or may not be normalized to be a distribution. 
+ */ + private[clustering] type TopicCounts = BDV[Double] + + private[clustering] type TokenCount = Double + + /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ + private[clustering] def term2index(term: Int): Long = -(1 + term.toLong) + + private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt + + private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + + private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 + + /** + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * @param graph EM graph, storing current parameter estimates in vertex descriptors and + * data (token counts) in edge descriptors. + * @param k Number of topics + * @param vocabSize Number of unique terms + * @param docConcentration "alpha" + * @param topicConcentration "beta" or "eta" + */ + private[clustering] class EMOptimizer( + var graph: Graph[TopicCounts, TokenCount], + val k: Int, + val vocabSize: Int, + val docConcentration: Double, + val topicConcentration: Double, + checkpointInterval: Int) { + + private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + graph, checkpointInterval) + + def next(): EMOptimizer = { + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // This is a hack to detect whether we could modify the values in-place. 
+ // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. + */ + var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + } + + /** + * Compute gamma_{wjk}, a distribution over topics k. + */ + private def computePTopic( + docTopicCounts: TopicCounts, + termTopicCounts: TopicCounts, + totalTopicCounts: TopicCounts, + vocabSize: Int, + eta: Double, + alpha: Double): TopicCounts = { + val K = docTopicCounts.length + val N_j = docTopicCounts.data + val N_w = termTopicCounts.data + val N = totalTopicCounts.data + val eta1 = eta - 1.0 + val alpha1 = alpha - 1.0 + val Weta1 = vocabSize * eta1 + var sum = 0.0 + val gamma_wj = new Array[Double](K) + var k = 0 + while (k < K) { + val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1) + gamma_wj(k) = gamma_wjk + sum += gamma_wjk + k += 1 + } + // normalize + BDV(gamma_wj) /= sum + } + + /** + * Compute bipartite term/doc graph. 
+ */ + private def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): EMOptimizer = { + // For each document, create an edge (Document -> Term) for each unique term in the document. + val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + def createVertices(): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } + } + verticesTMP.reduceByKey(_ + _) + } + + val docTermVertices = createVertices() + + // Partition such that edges are grouped by document + val graph = Graph(docTermVertices, edges) + .partitionBy(PartitionStrategy.EdgePartition1D) + + new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala new file mode 100644 index 0000000000000..b0e991d2f2344 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} +import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.BoundedPriorityQueue + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA) model. + * + * This abstraction permits different underlying representations, + * including local and distributed data structures. + */ +@Experimental +abstract class LDAModel private[clustering] { + + /** Number of topics */ + def k: Int + + /** Vocabulary size (number of terms or words in the vocabulary) */ + def vocabSize: Int + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + */ + def topicsMatrix: Matrix + + /** + * Return the topics described by weighted terms. + * + * This limits the number of terms per topic. + * This is approximate; it may not return exactly the top-weighted terms for each topic. + * To get a more precise set of top terms, increase maxTermsPerTopic. 
+ * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (term indices, term weights in topic). + * Each topic's terms are sorted in order of decreasing weight. + */ + def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] + + /** + * Return the topics described by weighted terms. + * + * WARNING: If vocabSize and k are large, this can return a large object! + * + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (term indices, term weights in topic). + * Each topic's terms are sorted in order of decreasing weight. + */ + def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) + + /* TODO (once LDA can be trained with Strings or given a dictionary) + * Return the topics described by weighted terms. + * + * This is similar to [[describeTopics()]] but returns String values for terms. + * If this model was trained using Strings or was given a dictionary, then this method returns + * terms as text. Otherwise, this method returns terms as term indices. + * + * This limits the number of terms per topic. + * This is approximate; it may not return exactly the top-weighted terms for each topic. + * To get a more precise set of top terms, increase maxTermsPerTopic. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (terms, term weights in topic) where terms are either the actual term text + * (if available) or the term indices. + * Each topic's terms are sorted in order of decreasing weight. + */ + // def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])] + + /* TODO (once LDA can be trained with Strings or given a dictionary) + * Return the topics described by weighted terms. 
+ * + * This is similar to [[describeTopics()]] but returns String values for terms. + * If this model was trained using Strings or was given a dictionary, then this method returns + * terms as text. Otherwise, this method returns terms as term indices. + * + * WARNING: If vocabSize and k are large, this can return a large object! + * + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (terms, term weights in topic) where terms are either the actual term text + * (if available) or the term indices. + * Each topic's terms are sorted in order of decreasing weight. + */ + // def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] = + // describeTopicsAsStrings(vocabSize) + + /* TODO + * Compute the log likelihood of the observed tokens, given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, alpha, eta) + * + * Note: + * - This excludes the prior. + * - Even with the prior, this is NOT the same as the data log likelihood given the + * hyperparameters. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * This must use the same vocabulary (ordering of term counts) as in training. + * Document IDs must be unique and >= 0. + * @return Estimated log likelihood of the data under this model + */ + // def logLikelihood(documents: RDD[(Long, Vector)]): Double + + /* TODO + * Compute the estimated topic distribution for each document. + * This is often called “theta” in the literature. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * This must use the same vocabulary (ordering of term counts) as in training. 
+ * Document IDs must be unique and >= 0. + * @return Estimated topic distribution for each document. + * The returned RDD may be zipped with the given RDD, where each returned vector + * is a multinomial distribution over topics. + */ + // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] + +} + +/** + * :: Experimental :: + * + * Local LDA model. + * This model stores only the inferred topics. + * It may be used for computing topics for new documents, but it may give less accurate answers + * than the [[DistributedLDAModel]]. + * + * @param topics Inferred topics (vocabSize x k matrix). + */ +@Experimental +class LocalLDAModel private[clustering] ( + private val topics: Matrix) extends LDAModel with Serializable { + + override def k: Int = topics.numCols + + override def vocabSize: Int = topics.numRows + + override def topicsMatrix: Matrix = topics + + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { + val brzTopics = topics.toBreeze.toDenseMatrix + Range(0, k).map { topicIndex => + val topic = normalize(brzTopics(::, topicIndex), 1.0) + val (termWeights, terms) = + topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip + (terms.toArray, termWeights.toArray) + }.toArray + } + + // TODO + // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + // TODO: + // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + +} + +/** + * :: Experimental :: + * + * Distributed LDA model. + * This model stores the inferred topics, the full training dataset, and the topic distributions. + * When computing topics for new documents, it may give more accurate answers + * than the [[LocalLDAModel]]. 
+ */ +@Experimental +class DistributedLDAModel private ( + private val graph: Graph[LDA.TopicCounts, LDA.TokenCount], + private val globalTopicTotals: LDA.TopicCounts, + val k: Int, + val vocabSize: Int, + private val docConcentration: Double, + private val topicConcentration: Double, + private[spark] val iterationTimes: Array[Double]) extends LDAModel { + + import LDA._ + + private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { + this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, + state.topicConcentration, iterationTimes) + } + + /** + * Convert model to a local model. + * The local model stores the inferred topics but not the topic distributions for training + * documents. + */ + def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix) + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + * + * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. + */ + override lazy val topicsMatrix: Matrix = { + // Collect row-major topics + val termTopicCounts: Array[(Int, TopicCounts)] = + graph.vertices.filter(_._1 < 0).map { case (termIndex, cnts) => + (index2term(termIndex), cnts) + }.collect() + // Convert to Matrix + val brzTopics = BDM.zeros[Double](vocabSize, k) + termTopicCounts.foreach { case (term, cnts) => + var j = 0 + while (j < k) { + brzTopics(term, j) = cnts(j) + j += 1 + } + } + Matrices.fromBreeze(brzTopics) + } + + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { + val numTopics = k + // Note: N_k is not needed to find the top terms, but it is needed to normalize weights + // to a distribution over terms. 
+ val N_k: TopicCounts = globalTopicTotals + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] = + graph.vertices.filter(isTermVertex) + .mapPartitions { termVertices => + // For this partition, collect the most common terms for each topic in queues: + // queues(topic) = queue of (term weight, term index). + // Term weights are N_{wk} / N_k. + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic)) + for ((termId, n_wk) <- termVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt)) + topic += 1 + } + } + Iterator(queues) + }.reduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b} + q1 + } + topicsInQueues.map { q => + val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } + } + + // TODO + // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + /** + * Log likelihood of the observed tokens in the training set, + * given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, alpha, eta) + * + * Note: + * - This excludes the prior; for that, use [[logPrior]]. + * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the + * hyperparameters. + */ + lazy val logLikelihood: Double = { + val eta = topicConcentration + val alpha = docConcentration + assert(eta > 1.0) + assert(alpha > 1.0) + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + // Edges: Compute token log probability from phi_{wk}, theta_{kj}. 
+ val sendMsg: EdgeContext[TopicCounts, TokenCount, Double] => Unit = (edgeContext) => { + val N_wj = edgeContext.attr + val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) + val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) + edgeContext.sendToDst(tokenLogLikelihood) + } + graph.aggregateMessages[Double](sendMsg, _ + _) + .map(_._2).fold(0.0)(_ + _) + } + + /** + * Log probability of the current parameter estimate: + * log P(topics, topic distributions for docs | alpha, eta) + */ + lazy val logPrior: Double = { + val eta = topicConcentration + val alpha = docConcentration + // Term vertices: Compute phi_{wk}. Use to compute prior log probability. + // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + val seqOp: (Double, (VertexId, TopicCounts)) => Double = { + case (sumPrior: Double, vertex: (VertexId, TopicCounts)) => + if (isTermVertex(vertex)) { + val N_wk = vertex._2 + val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + sumPrior + (eta - 1.0) * brzSum(phi_wk.map(math.log)) + } else { + val N_kj = vertex._2 + val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + sumPrior + (alpha - 1.0) * brzSum(theta_kj.map(math.log)) + } + } + graph.vertices.aggregate(0.0)(seqOp, _ + _) + } + + /** + * For each document in the training set, return the distribution over topics for that document + * ("theta_doc"). 
+ * + * @return RDD of (document ID, topic distribution) pairs + */ + def topicDistributions: RDD[(Long, Vector)] = { + graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => + (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) + } + } + + // TODO: + // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala new file mode 100644 index 0000000000000..180023922a9b0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + +/** + * :: Experimental :: + * + * Model produced by [[PowerIterationClustering]]. + * + * @param k number of clusters + * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s + */ +@Experimental +class PowerIterationClusteringModel( + val k: Int, + val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable + +/** + * :: Experimental :: + * + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + * [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. From the abstract: PIC finds a very + * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise + * similarity matrix of the data. + * + * @param k Number of clusters. + * @param maxIterations Maximum number of iterations of the PIC algorithm. + * @param initMode Initialization mode. + * + * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] + */ +@Experimental +class PowerIterationClustering private[clustering] ( + private var k: Int, + private var maxIterations: Int, + private var initMode: String) extends Serializable { + + import org.apache.spark.mllib.clustering.PowerIterationClustering._ + + /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, + * initMode: "random"}. + */ + def this() = this(k = 2, maxIterations = 100, initMode = "random") + + /** + * Set the number of clusters. 
+ */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** + * Set maximum number of iterations of the power iteration loop + */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** + * Set the initialization mode. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use normalized sum similarities. Default: random. + */ + def setInitializationMode(mode: String): this.type = { + this.initMode = mode match { + case "random" | "degree" => mode + case _ => throw new IllegalArgumentException("Invalid initialization mode: " + mode) + } + this + } + + /** + * Run the PIC algorithm. + * + * @param similarities an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix, which is + * the matrix A in the PIC paper. The similarity s,,ij,, must be nonnegative. + * This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For any (i, j) with + * nonzero similarity, there should be either (i, j, s,,ij,,) or + * (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we + * assume s,,ij,, = 0.0. + * + * @return a [[PowerIterationClusteringModel]] that contains the clustering result + */ + def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = { + val w = normalize(similarities) + val w0 = initMode match { + case "random" => randomInit(w) + case "degree" => initDegreeVector(w) + } + pic(w0) + } + + /** + * A Java-friendly version of [[PowerIterationClustering.run]]. + */ + def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)]) + : PowerIterationClusteringModel = { + run(similarities.rdd.asInstanceOf[RDD[(Long, Long, Double)]]) + } + + /** + * Runs the PIC algorithm. + * + * @param w The normalized affinity matrix, which is the matrix W in the PIC paper with + * w,,ij,, = a,,ij,, / d,,ii,, as its edge properties and the initial vector of the power + * iteration as its vertex properties. 
+ */ + private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = { + val v = powerIter(w, maxIterations) + val assignments = kMeans(v, k).mapPartitions({ iter => + iter.map { case (id, cluster) => + new Assignment(id, cluster) + } + }, preservesPartitioning = true) + new PowerIterationClusteringModel(k, assignments) + } +} + +@Experimental +object PowerIterationClustering extends Logging { + + /** + * :: Experimental :: + * Cluster assignment. + * @param id node id + * @param cluster assigned cluster id + */ + @Experimental + class Assignment(val id: Long, val cluster: Int) extends Serializable + + /** + * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). + */ + private[clustering] + def normalize(similarities: RDD[(Long, Long, Double)]) + : Graph[Double, Double] = { + val edges = similarities.flatMap { case (i, j, s) => + if (s < 0.0) { + throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.") + } + if (i != j) { + Seq(Edge(i, j, s), Edge(j, i, s)) + } else { + None + } + } + val gA = Graph.fromEdges(edges, 0.0) + val vD = gA.aggregateMessages[Double]( + sendMsg = ctx => { + ctx.sendToSrc(ctx.attr) + }, + mergeMsg = _ + _, + TripletFields.EdgeOnly) + GraphImpl.fromExistingRDDs(vD, gA.edges) + .mapTriplets( + e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), + TripletFields.Src) + } + + /** + * Generates random vertex properties (v0) to start power iteration. 
+ * + * @param g a graph representing the normalized affinity matrix (W) + * @return a graph with edges representing W and vertices representing a random vector + * with unit 1-norm + */ + private[clustering] + def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = { + val r = g.vertices.mapPartitionsWithIndex( + (part, iter) => { + val random = new XORShiftRandom(part) + iter.map { case (id, _) => + (id, random.nextGaussian()) + } + }, preservesPartitioning = true).cache() + val sum = r.values.map(math.abs).sum() + val v0 = r.mapValues(x => x / sum) + GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + } + + /** + * Generates the degree vector as the vertex properties (v0) to start power iteration. + * It is not exactly the node degrees but just the normalized sum similarities. Call it + * as degree vector because it is used in the PIC paper. + * + * @param g a graph representing the normalized affinity matrix (W) + * @return a graph with edges representing W and vertices representing the degree vector + */ + private[clustering] + def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = { + val sum = g.vertices.values.sum() + val v0 = g.vertices.mapValues(_ / sum) + GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + } + + /** + * Runs power iteration. + * @param g input graph with edges representing the normalized affinity matrix (W) and vertices + * representing the initial vector of the power iterations. 
+ * @param maxIterations maximum number of iterations + * @return a [[VertexRDD]] representing the pseudo-eigenvector + */ + private[clustering] + def powerIter( + g: Graph[Double, Double], + maxIterations: Int): VertexRDD[Double] = { + // the default tolerance used in the PIC paper, with a lower bound 1e-8 + val tol = math.max(1e-5 / g.vertices.count(), 1e-8) + var prevDelta = Double.MaxValue + var diffDelta = Double.MaxValue + var curG = g + for (iter <- 0 until maxIterations if math.abs(diffDelta) > tol) { + val msgPrefix = s"Iteration $iter" + // multiply W by vt + val v = curG.aggregateMessages[Double]( + sendMsg = ctx => ctx.sendToSrc(ctx.attr * ctx.dstAttr), + mergeMsg = _ + _, + TripletFields.Dst).cache() + // normalize v + val norm = v.values.map(math.abs).sum() + logInfo(s"$msgPrefix: norm(v) = $norm.") + val v1 = v.mapValues(x => x / norm) + // compare difference + val delta = curG.joinVertices(v1) { case (_, x, y) => + math.abs(x - y) + }.vertices.values.sum() + logInfo(s"$msgPrefix: delta = $delta.") + diffDelta = math.abs(delta - prevDelta) + logInfo(s"$msgPrefix: diff(delta) = $diffDelta.") + // update v + curG = GraphImpl.fromExistingRDDs(VertexRDD(v1), g.edges) + prevDelta = delta + } + curG.vertices + } + + /** + * Runs k-means clustering. 
+ * @param v a [[VertexRDD]] representing the pseudo-eigenvector + * @param k number of clusters + * @return a [[VertexRDD]] representing the clustering assignments + */ + private[clustering] + def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = { + val points = v.mapValues(x => Vectors.dense(x)).cache() + val model = new KMeans() + .setK(k) + .setRuns(5) + .setSeed(0L) + .run(points.values) + points.mapValues(p => model.predict(p)).cache() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 7752c1988fdd1..f483fd1c7d2cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.DStream @@ -29,7 +29,8 @@ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** - * :: DeveloperApi :: + * :: Experimental :: + * * StreamingKMeansModel extends MLlib's KMeansModel for streaming * algorithms, so it can keep track of a continuously updated weight * associated with each cluster, and also update the model by @@ -39,8 +40,10 @@ import org.apache.spark.util.random.XORShiftRandom * generalized to incorporate forgetfullness (i.e. decay). 
* The update rule (for each cluster) is: * + * {{{ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] * n_t+t = n_t * a + m_t + * }}} * * Where c_t is the previously estimated centroid for that cluster, * n_t is the number of points assigned to it thus far, x_t is the centroid @@ -61,7 +64,7 @@ import org.apache.spark.util.random.XORShiftRandom * as batches or points. * */ -@DeveloperApi +@Experimental class StreamingKMeansModel( override val clusterCenters: Array[Vector], val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { @@ -140,7 +143,8 @@ class StreamingKMeansModel( } /** - * :: DeveloperApi :: + * :: Experimental :: + * * StreamingKMeans provides methods for configuring a * streaming k-means analysis, training the model on streaming, * and using the model to make predictions on streaming data. @@ -149,13 +153,15 @@ class StreamingKMeansModel( * Use a builder pattern to construct a streaming k-means analysis * in an application, like: * + * {{{ * val model = new StreamingKMeans() * .setDecayFactor(0.5) * .setK(3) * .setRandomCenters(5, 100.0) * .trainOn(DStream) + * }}} */ -@DeveloperApi +@Experimental class StreamingKMeans( var k: Int, var decayFactor: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala new file mode 100644 index 0000000000000..c6057c7f837b1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Chi Squared selector model. + * + * @param selectedFeatures list of indices to select (filter). Must be ordered asc + */ +@Experimental +class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer { + + require(isSorted(selectedFeatures), "Array has to be sorted asc") + + protected def isSorted(array: Array[Int]): Boolean = { + var i = 1 + while (i < array.length) { + if (array(i) < array(i-1)) return false + i += 1 + } + true + } + + /** + * Applies transformation on a vector. + * + * @param vector vector to be transformed. + * @return transformed vector. + */ + override def transform(vector: Vector): Vector = { + compress(vector, selectedFeatures) + } + + /** + * Returns a vector with features filtered. + * Preserves the order of filtered features the same as their indices are stored. 
+ * Might be moved to Vector as .slice + * @param features vector + * @param filterIndices indices of features to filter, must be ordered asc + */ + private def compress(features: Vector, filterIndices: Array[Int]): Vector = { + features match { + case SparseVector(size, indices, values) => + val newSize = filterIndices.length + val newValues = new ArrayBuilder.ofDouble + val newIndices = new ArrayBuilder.ofInt + var i = 0 + var j = 0 + var indicesIdx = 0 + var filterIndicesIdx = 0 + while (i < indices.length && j < filterIndices.length) { + indicesIdx = indices(i) + filterIndicesIdx = filterIndices(j) + if (indicesIdx == filterIndicesIdx) { + newIndices += j + newValues += values(i) + j += 1 + i += 1 + } else { + if (indicesIdx > filterIndicesIdx) { + j += 1 + } else { + i += 1 + } + } + } + // TODO: Sparse representation might be ineffective if (newSize ~= newValues.size) + Vectors.sparse(newSize, newIndices.result(), newValues.result()) + case DenseVector(values) => + val values = features.toArray + Vectors.dense(filterIndices.map(i => values(i))) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } +} + +/** + * :: Experimental :: + * Creates a ChiSquared feature selector. + * @param numTopFeatures number of features that selector will select + * (ordered by statistic value descending) + */ +@Experimental +class ChiSqSelector (val numTopFeatures: Int) { + + /** + * Returns a ChiSquared feature selector. + * + * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. + * Real-valued features will be treated as categorical for each distinct value. + * Apply feature discretizer before using this function. 
+ */ + def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { + val indices = Statistics.chiSqTest(data) + .zipWithIndex.sortBy { case (res, _) => -res.statistic } + .take(numTopFeatures) + .map { case (_, indices) => indices } + .sorted + new ChiSqSelectorModel(indices) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 3260f27513c7f..a89eea0e21be2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -22,7 +22,6 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 3c2091732f9b0..6ae6917eae595 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -18,15 +18,14 @@ package org.apache.spark.mllib.feature import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD /** * :: Experimental :: - * Standardizes features by removing the mean and scaling to unit variance using column summary + * Standardizes features by removing the mean and scaling to unit std using column summary * statistics on the samples in the training set. 
* * @param withMean False by default. Centers the data with mean before scaling. It will build a @@ -53,7 +52,11 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) - new StandardScalerModel(withMean, withStd, summary.mean, summary.variance) + new StandardScalerModel( + Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))), + summary.mean, + withStd, + withMean) } } @@ -61,28 +64,43 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * :: Experimental :: * Represents a StandardScaler model that can transform vectors. * - * @param withMean whether to center the data before scaling - * @param withStd whether to scale the data to have unit standard deviation + * @param std column standard deviation values * @param mean column mean values - * @param variance column variance values + * @param withStd whether to scale the data to have unit standard deviation + * @param withMean whether to center the data before scaling */ @Experimental -class StandardScalerModel private[mllib] ( - val withMean: Boolean, - val withStd: Boolean, +class StandardScalerModel ( + val std: Vector, val mean: Vector, - val variance: Vector) extends VectorTransformer { - - require(mean.size == variance.size) + var withStd: Boolean, + var withMean: Boolean) extends VectorTransformer { - private lazy val factor: Array[Double] = { - val f = Array.ofDim[Double](variance.size) - var i = 0 - while (i < f.size) { - f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 - i += 1 + def this(std: Vector, mean: Vector) { + this(std, mean, withStd = std != null, withMean = mean != null) + require(this.withStd || this.withMean, + "at least one of std or mean vectors must be provided") + if (this.withStd && this.withMean) { + require(mean.size == std.size, + "mean 
and std vectors must have equal size if both are provided") } - f + } + + def this(std: Vector) = this(std, null) + + @DeveloperApi + def setWithMean(withMean: Boolean): this.type = { + require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null") + this.withMean = withMean + this + } + + @DeveloperApi + def setWithStd(withStd: Boolean): this.type = { + require(!(withStd && this.std == null), + "cannot set withStd to true while std is null") + this.withStd = withStd + this } // Since `shift` will be only used in `withMean` branch, we have it as @@ -94,8 +112,8 @@ class StandardScalerModel private[mllib] ( * Applies standardization transformation on a vector. * * @param vector Vector to be standardized. - * @return Standardized vector. If the variance of a column is zero, it will return default `0.0` - * for the column with zero variance. + * @return Standardized vector. If the std of a column is zero, it will return default `0.0` + * for the column with zero std. */ override def transform(vector: Vector): Vector = { require(mean.size == vector.size) @@ -109,11 +127,9 @@ class StandardScalerModel private[mllib] ( val values = vs.clone() val size = values.size if (withStd) { - // Having a local reference of `factor` to avoid overhead as the comment before. - val localFactor = factor var i = 0 while (i < size) { - values(i) = (values(i) - localShift(i)) * localFactor(i) + values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0 i += 1 } } else { @@ -127,15 +143,13 @@ class StandardScalerModel private[mllib] ( case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else if (withStd) { - // Having a local reference of `factor` to avoid overhead as the comment before. 
- val localFactor = factor vector match { case DenseVector(vs) => val values = vs.clone() val size = values.size var i = 0 while(i < size) { - values(i) *= localFactor(i) + values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0) i += 1 } Vectors.dense(values) @@ -146,7 +160,7 @@ class StandardScalerModel private[mllib] ( val nnz = values.size var i = 0 while (i < nnz) { - values(i) *= localFactor(indices(i)) + values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0) i += 1 } Vectors.sparse(size, indices, values) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d25a7cd5b439d..59a79e5c6a4ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -272,7 +272,7 @@ class Word2Vec extends Serializable with Logging { def hasNext: Boolean = iter.hasNext def next(): Array[Int] = { - var sentence = new ArrayBuffer[Int] + val sentence = ArrayBuilder.make[Int] var sentenceLength = 0 while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { val word = bcVocabHash.value.get(iter.next()) @@ -283,13 +283,20 @@ class Word2Vec extends Serializable with Logging { case None => } } - sentence.toArray + sentence.result() } } } val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) + + if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + + " to avoid an OOM. 
You are highly recommended to make your vocabSize*vectorSize, " + + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + } + val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala new file mode 100644 index 0000000000000..efa8459d3cdba --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.fpm + +import java.{util => ju} +import java.lang.{Iterable => JavaIterable} + +import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * :: Experimental :: + * + * Model trained by [[FPGrowth]], which holds frequent itemsets. + * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] + * @tparam Item item type + */ +@Experimental +class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable + +/** + * :: Experimental :: + * + * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in + * [[http://dx.doi.org/10.1145/1454008.1454027 Li et al., PFP: Parallel FP-Growth for Query + * Recommendation]]. PFP distributes computation in such a way that each worker executes an + * independent group of mining tasks. The FP-Growth algorithm is described in + * [[http://dx.doi.org/10.1145/335191.335372 Han et al., Mining frequent patterns without candidate + * generation]]. 
+ * + * @param minSupport the minimal support level of the frequent pattern, any pattern that appears + * at least (minSupport * size-of-the-dataset) times will be output + * @param numPartitions number of partitions used by parallel FP-growth + * + * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning + * (Wikipedia)]] + */ +@Experimental +class FPGrowth private ( + private var minSupport: Double, + private var numPartitions: Int) extends Logging with Serializable { + + /** + * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same + * as the input data}. + */ + def this() = this(0.3, -1) + + /** + * Sets the minimal support level (default: `0.3`). + */ + def setMinSupport(minSupport: Double): this.type = { + this.minSupport = minSupport + this + } + + /** + * Sets the number of partitions used by parallel FP-growth (default: same as input data). + */ + def setNumPartitions(numPartitions: Int): this.type = { + this.numPartitions = numPartitions + this + } + + /** + * Computes an FP-Growth model that contains frequent itemsets.
+ * @param data input data set, each element contains a transaction + * @return an [[FPGrowthModel]] + */ + def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = { + if (data.getStorageLevel == StorageLevel.NONE) { + logWarning("Input data is not cached.") + } + val count = data.count() + val minCount = math.ceil(minSupport * count).toLong + val numParts = if (numPartitions > 0) numPartitions else data.partitions.length + val partitioner = new HashPartitioner(numParts) + val freqItems = genFreqItems(data, minCount, partitioner) + val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner) + new FPGrowthModel(freqItemsets) + } + + def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { + implicit val tag = fakeClassTag[Item] + run(data.rdd.map(_.asScala.toArray)) + } + + /** + * Generates frequent items by filtering the input data using minimal support level. + * @param minCount minimum count for frequent itemsets + * @param partitioner partitioner used to distribute items + * @return array of frequent pattern ordered by their frequencies + */ + private def genFreqItems[Item: ClassTag]( + data: RDD[Array[Item]], + minCount: Long, + partitioner: Partitioner): Array[Item] = { + data.flatMap { t => + val uniq = t.toSet + if (t.size != uniq.size) { + throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.") + } + t + }.map(v => (v, 1L)) + .reduceByKey(partitioner, _ + _) + .filter(_._2 >= minCount) + .collect() + .sortBy(-_._2) + .map(_._1) + } + + /** + * Generate frequent itemsets by building FP-Trees, the extraction is done on each partition. 
+ * @param data transactions + * @param minCount minimum count for frequent itemsets + * @param freqItems frequent items + * @param partitioner partitioner used to distribute transactions + * @return an RDD of (frequent itemset, count) + */ + private def genFreqItemsets[Item: ClassTag]( + data: RDD[Array[Item]], + minCount: Long, + freqItems: Array[Item], + partitioner: Partitioner): RDD[FreqItemset[Item]] = { + val itemToRank = freqItems.zipWithIndex.toMap + data.flatMap { transaction => + genCondTransactions(transaction, itemToRank, partitioner) + }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)( + (tree, transaction) => tree.add(transaction, 1L), + (tree1, tree2) => tree1.merge(tree2)) + .flatMap { case (part, tree) => + tree.extract(minCount, x => partitioner.getPartition(x) == part) + }.map { case (ranks, count) => + new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) + } + } + + /** + * Generates conditional transactions. + * @param transaction a transaction + * @param itemToRank map from item to their rank + * @param partitioner partitioner used to distribute transactions + * @return a map of (target partition, conditional transaction) + */ + private def genCondTransactions[Item: ClassTag]( + transaction: Array[Item], + itemToRank: Map[Item, Int], + partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { + val output = mutable.Map.empty[Int, Array[Int]] + // Filter the basket by frequent items pattern and sort their ranks. + val filtered = transaction.flatMap(itemToRank.get) + ju.Arrays.sort(filtered) + val n = filtered.length + var i = n - 1 + while (i >= 0) { + val item = filtered(i) + val part = partitioner.getPartition(item) + if (!output.contains(part)) { + output(part) = filtered.slice(0, i + 1) + } + i -= 1 + } + output + } +} + +/** + * :: Experimental :: + */ +@Experimental +object FPGrowth { + + /** + * Frequent itemset. + * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. 
+ * @param freq frequency + * @tparam Item item type + */ + class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + + /** + * Returns items in a Java List. + */ + def javaItems: java.util.List[Item] = { + items.toList.asJava + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala new file mode 100644 index 0000000000000..1d2d777c00793 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +/** + * FP-Tree data structure used in FP-Growth. + * @tparam T item type + */ +private[fpm] class FPTree[T] extends Serializable { + + import FPTree._ + + val root: Node[T] = new Node(null) + + private val summaries: mutable.Map[T, Summary[T]] = mutable.Map.empty + + /** Adds a transaction with count. 
*/ + def add(t: Iterable[T], count: Long = 1L): this.type = { + require(count > 0) + var curr = root + curr.count += count + t.foreach { item => + val summary = summaries.getOrElseUpdate(item, new Summary) + summary.count += count + val child = curr.children.getOrElseUpdate(item, { + val newNode = new Node(curr) + newNode.item = item + summary.nodes += newNode + newNode + }) + child.count += count + curr = child + } + this + } + + /** Merges another FP-Tree. */ + def merge(other: FPTree[T]): this.type = { + other.transactions.foreach { case (t, c) => + add(t, c) + } + this + } + + /** Gets a subtree with the suffix. */ + private def project(suffix: T): FPTree[T] = { + val tree = new FPTree[T] + if (summaries.contains(suffix)) { + val summary = summaries(suffix) + summary.nodes.foreach { node => + var t = List.empty[T] + var curr = node.parent + while (!curr.isRoot) { + t = curr.item :: t + curr = curr.parent + } + tree.add(t, node.count) + } + } + tree + } + + /** Returns all transactions in an iterator. */ + def transactions: Iterator[(List[T], Long)] = getTransactions(root) + + /** Returns all transactions under this node. */ + private def getTransactions(node: Node[T]): Iterator[(List[T], Long)] = { + var count = node.count + node.children.iterator.flatMap { case (item, child) => + getTransactions(child).map { case (t, c) => + count -= c + (item :: t, c) + } + } ++ { + if (count > 0) { + Iterator.single((Nil, count)) + } else { + Iterator.empty + } + } + } + + /** Extracts all patterns with valid suffix and minimum count. 
*/ + def extract( + minCount: Long, + validateSuffix: T => Boolean = _ => true): Iterator[(List[T], Long)] = { + summaries.iterator.flatMap { case (item, summary) => + if (validateSuffix(item) && summary.count >= minCount) { + Iterator.single((item :: Nil, summary.count)) ++ + project(item).extract(minCount).map { case (t, c) => + (item :: t, c) + } + } else { + Iterator.empty + } + } + } +} + +private[fpm] object FPTree { + + /** Representing a node in an FP-Tree. */ + class Node[T](val parent: Node[T]) extends Serializable { + var item: T = _ + var count: Long = 0L + val children: mutable.Map[T, Node[T]] = mutable.Map.empty + + def isRoot: Boolean = parent == null + } + + /** Summary of an item in an FP-Tree. */ + private class Summary[T] extends Serializable { + var count: Long = 0L + val nodes: ListBuffer[Node[T]] = ListBuffer.empty + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala new file mode 100644 index 0000000000000..6e5dd119dd653 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.Logging +import org.apache.spark.graphx.Graph +import org.apache.spark.storage.StorageLevel + + +/** + * This class helps with persisting and checkpointing Graphs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created, + * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are + * responsible for materializing the graph to ensure that persisting and checkpointing actually + * occur. + * + * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following: + * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. + * - Unpersist graphs from queue until there are at most 3 persisted graphs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new graph, and put in a queue of checkpointed graphs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Graphs should be + * checkpointed). + * - This class removes checkpoint files once later graphs have been checkpointed. + * However, references to the older graphs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (graph1, graph2, graph3, ...) = ... 
+ * val cp = new PeriodicGraphCheckpointer(graph1, 2) + * graph1.vertices.count(); graph1.edges.count() + * // persisted: graph1 + * cp.updateGraph(graph2) + * graph2.vertices.count(); graph2.edges.count() + * // persisted: graph1, graph2 + * // checkpointed: graph2 + * cp.updateGraph(graph3) + * graph3.vertices.count(); graph3.edges.count() + * // persisted: graph1, graph2, graph3 + * // checkpointed: graph2 + * cp.updateGraph(graph4) + * graph4.vertices.count(); graph4.edges.count() + * // persisted: graph2, graph3, graph4 + * // checkpointed: graph4 + * cp.updateGraph(graph5) + * graph5.vertices.count(); graph5.edges.count() + * // persisted: graph3, graph4, graph5 + * // checkpointed: graph4 + * }}} + * + * @param currentGraph Initial graph + * @param checkpointInterval Graphs will be checkpointed at this interval + * @tparam VD Vertex descriptor type + * @tparam ED Edge descriptor type + * + * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib. + */ +private[mllib] class PeriodicGraphCheckpointer[VD, ED]( + var currentGraph: Graph[VD, ED], + val checkpointInterval: Int) extends Logging { + + /** FIFO queue of past checkpointed Graphs */ + private val checkpointQueue = mutable.Queue[Graph[VD, ED]]() + + /** FIFO queue of past persisted Graphs */ + private val persistedQueue = mutable.Queue[Graph[VD, ED]]() + + /** Number of times [[updateGraph()]] has been called */ + private var updateCount = 0 + + /** + * Spark Context for the Graphs given to this checkpointer. + * NOTE: This code assumes that only one SparkContext is used for the given graphs. + */ + private val sc = currentGraph.vertices.sparkContext + + updateGraph(currentGraph) + + /** + * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the graph + * has been materialized. + * + * @param newGraph New graph created from previous graphs in the lineage.
+ */ + def updateGraph(newGraph: Graph[VD, ED]): Unit = { + if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) { + newGraph.persist() + } + persistedQueue.enqueue(newGraph) + // We try to maintain 3 Graphs in persistedQueue to support the semantics of this class: + // Users should call [[updateGraph()]] when a new graph has been created, + // before the graph has been materialized. + while (persistedQueue.size > 3) { + val graphToUnpersist = persistedQueue.dequeue() + graphToUnpersist.unpersist(blocking = false) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + newGraph.checkpoint() + checkpointQueue.enqueue(newGraph) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. + if (checkpointQueue.get(1).get.isCheckpointed) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.size > 0) { + removeCheckpointFile() + } + } + + /** + * Dequeue the oldest checkpointed Graph, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it.
+ val fs = FileSystem.get(sc.hadoopConfiguration) + old.getCheckpointFiles.foreach { checkpointFile => + try { + fs.delete(new Path(checkpointFile), true) + } catch { + case e: Exception => + logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 3414daccd7ca4..87052e1ba8539 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -235,12 +235,24 @@ private[spark] object BLAS extends Serializable with Logging { * @param x the vector x that contains the n elements. * @param A the symmetric matrix A. Size of n x n. */ - def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + def syr(alpha: Double, x: Vector, A: DenseMatrix) { val mA = A.numRows val nA = A.numCols - require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA") + require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA") require(mA == x.size, s"The size of x doesn't match the rank of A. 
A: $mA x $nA, x: ${x.size}") + x match { + case dv: DenseVector => syr(alpha, dv, A) + case sv: SparseVector => syr(alpha, sv, A) + case _ => + throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.") + } + } + + private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + val nA = A.numRows + val mA = A.numCols + nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA) // Fill lower triangular part of A @@ -255,82 +267,77 @@ private[spark] object BLAS extends Serializable with Logging { } } + private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { + val mA = A.numCols + val xIndices = x.indices + val xValues = x.values + val nnz = xValues.length + val Avalues = A.values + + var i = 0 + while (i < nnz) { + val multiplier = alpha * xValues(i) + val offset = xIndices(i) * mA + var j = 0 + while (j < nnz) { + Avalues(xIndices(j) + offset) += multiplier * xValues(j) + j += 1 + } + i += 1 + } + } + /** * C := alpha * A * B + beta * C - * @param transA whether to use the transpose of matrix A (true), or A itself (false). - * @param transB whether to use the transpose of matrix B (true), or B itself (false). * @param alpha a scalar to scale the multiplication A * B. * @param A the matrix A that will be left multiplied to B. Size of m x k. * @param B the matrix B that will be left multiplied by A. Size of k x n. * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. + * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false. */ def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: Matrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { + require(!C.isTransposed, + "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0) { logDebug("gemm: alpha is equal to 0. 
Returning C.") } else { A match { - case sparse: SparseMatrix => - gemm(transA, transB, alpha, sparse, B, beta, C) - case dense: DenseMatrix => - gemm(transA, transB, alpha, dense, B, beta, C) + case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C) + case dense: DenseMatrix => gemm(alpha, dense, B, beta, C) case _ => throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.") } } } - /** - * C := alpha * A * B + beta * C - * - * @param alpha a scalar to scale the multiplication A * B. - * @param A the matrix A that will be left multiplied to B. Size of m x k. - * @param B the matrix B that will be left multiplied by A. Size of k x n. - * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. - */ - def gemm( - alpha: Double, - A: Matrix, - B: DenseMatrix, - beta: Double, - C: DenseMatrix): Unit = { - gemm(false, false, alpha, A, B, beta, C) - } - /** * C := alpha * A * B + beta * C * For `DenseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: DenseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols - val tAstr = if (!transA) "N" else "T" - val tBstr = if (!transB) "N" else "T" - - require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") - require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") - require(nB == C.numCols, - s"The columns of C don't match the columns of B. 
C: ${C.numCols}, A: $nB") - - nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows, - beta, C.values, C.numRows) + val tAstr = if (A.isTransposed) "T" else "N" + val tBstr = if (B.isTransposed) "T" else "N" + val lda = if (!A.isTransposed) A.numRows else A.numCols + val ldb = if (!B.isTransposed) B.numRows else B.numCols + + require(A.numCols == B.numRows, + s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}") + require(A.numRows == C.numRows, + s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}") + require(B.numCols == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: ${B.numCols}") + nativeBLAS.dgemm(tAstr, tBstr, A.numRows, B.numCols, A.numCols, alpha, A.values, lda, + B.values, ldb, beta, C.values, C.numRows) } /** @@ -338,17 +345,15 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: SparseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols + val mA: Int = A.numRows + val nB: Int = B.numCols + val kA: Int = A.numCols + val kB: Int = B.numRows require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") @@ -358,23 +363,23 @@ private[spark] object BLAS extends Serializable with Logging { val Avals = A.values val Bvals = B.values val Cvals = C.values - val Arows = if (!transA) A.rowIndices else A.colPtrs - val Acols = if (!transA) A.colPtrs else A.rowIndices + val ArowIndices = A.rowIndices + val AcolPtrs = A.colPtrs // Slicing is easy in this case. 
This is the optimal multiplication setting for sparse matrices - if (transA){ + if (A.isTransposed){ var colCounterForB = 0 - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var rowCounterForA = 0 val Cstart = colCounterForB * mA val Bstart = colCounterForB * kA while (rowCounterForA < mA) { - var i = Arows(rowCounterForA) - val indEnd = Arows(rowCounterForA + 1) + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * Bvals(Bstart + Acols(i)) + sum += Avals(i) * Bvals(Bstart + ArowIndices(i)) i += 1 } val Cindex = Cstart + rowCounterForA @@ -385,19 +390,19 @@ private[spark] object BLAS extends Serializable with Logging { } } else { while (colCounterForB < nB) { - var rowCounter = 0 + var rowCounterForA = 0 val Cstart = colCounterForB * mA - while (rowCounter < mA) { - var i = Arows(rowCounter) - val indEnd = Arows(rowCounter + 1) + while (rowCounterForA < mA) { + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * B(colCounterForB, Acols(i)) + sum += Avals(i) * B(ArowIndices(i), colCounterForB) i += 1 } - val Cindex = Cstart + rowCounter + val Cindex = Cstart + rowCounterForA Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha - rowCounter += 1 + rowCounterForA += 1 } colCounterForB += 1 } @@ -410,17 +415,17 @@ private[spark] object BLAS extends Serializable with Logging { // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of // B, and added to C. 
var colCounterForB = 0 // the column to be updated in C - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var colCounterForA = 0 // The column of A to multiply with the row of B val Bstart = colCounterForB * kB val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) val Bval = Bvals(Bstart + colCounterForA) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -432,11 +437,11 @@ private[spark] object BLAS extends Serializable with Logging { var colCounterForA = 0 // The column of A to multiply with the row of B val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) - val Bval = B(colCounterForB, colCounterForA) * alpha + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) + val Bval = B(colCounterForA, colCounterForB) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -449,7 +454,6 @@ private[spark] object BLAS extends Serializable with Logging { /** * y := alpha * A * x + beta * y - * @param trans whether to use the transpose of matrix A (true), or A itself (false). * @param alpha a scalar to scale the multiplication A * x. * @param A the matrix A that will be left multiplied to x. Size of m x n. * @param x the vector x that will be left multiplied by A. Size of n x 1. @@ -457,65 +461,43 @@ private[spark] object BLAS extends Serializable with Logging { * @param y the resulting vector y. Size of m x 1. 
*/ def gemv( - trans: Boolean, alpha: Double, A: Matrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - - val mA: Int = if (!trans) A.numRows else A.numCols - val nx: Int = x.size - val nA: Int = if (!trans) A.numCols else A.numRows - - require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx") - require(mA == y.size, - s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}") + require(A.numCols == x.size, + s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") + require(A.numRows == y.size, + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { A match { case sparse: SparseMatrix => - gemv(trans, alpha, sparse, x, beta, y) + gemv(alpha, sparse, x, beta, y) case dense: DenseMatrix => - gemv(trans, alpha, dense, x, beta, y) + gemv(alpha, dense, x, beta, y) case _ => throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") } } } - /** - * y := alpha * A * x + beta * y - * - * @param alpha a scalar to scale the multiplication A * x. - * @param A the matrix A that will be left multiplied to x. Size of m x n. - * @param x the vector x that will be left multiplied by A. Size of n x 1. - * @param beta a scalar that can be used to scale vector y. - * @param y the resulting vector y. Size of m x 1. - */ - def gemv( - alpha: Double, - A: Matrix, - x: DenseVector, - beta: Double, - y: DenseVector): Unit = { - gemv(false, alpha, A, x, beta, y) - } - /** * y := alpha * A * x + beta * y * For `DenseMatrix` A. 
*/ private def gemv( - trans: Boolean, alpha: Double, A: DenseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val tStrA = if (!trans) "N" else "T" - nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta, + val tStrA = if (A.isTransposed) "T" else "N" + val mA = if (!A.isTransposed) A.numRows else A.numCols + val nA = if (!A.isTransposed) A.numCols else A.numRows + nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, y.values, 1) } @@ -524,24 +506,21 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: SparseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val xValues = x.values val yValues = y.values - - val mA: Int = if (!trans) A.numRows else A.numCols - val nA: Int = if (!trans) A.numCols else A.numRows + val mA: Int = A.numRows + val nA: Int = A.numCols val Avals = A.values - val Arows = if (!trans) A.rowIndices else A.colPtrs - val Acols = if (!trans) A.colPtrs else A.rowIndices + val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs + val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices // Slicing is easy in this case. 
This is the optimal multiplication setting for sparse matrices - if (trans) { + if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { var i = Arows(rowCounter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 3515461b52493..866936aa4f118 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -79,6 +79,9 @@ private[mllib] object EigenValueDecomposition { // Mode 1: A*x = lambda*x, A symmetric iparam(6) = 1 + require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE, + s"k = $k and/or n = $n are too large to compute an eigendecomposition") + var ido = new intW(0) var info = new intW(0) var resid = new Array[Double](n) @@ -114,7 +117,7 @@ private[mllib] object EigenValueDecomposition { info.`val` match { case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + " Maximum number of iterations taken. (Refer ARPACK user guide for details)") - case 2 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + + case 3 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + " No shifts could be applied. Try to increase NCV. " + "(Refer ARPACK user guide for details)") case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 5a7281ec6dc3c..0e4a4d0085895 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -34,14 +34,23 @@ sealed trait Matrix extends Serializable { /** Number of columns. 
*/ def numCols: Int + /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + val isTransposed: Boolean = false + /** Converts to a dense array in column major. */ - def toArray: Array[Double] + def toArray: Array[Double] = { + val newArray = new Array[Double](numRows * numCols) + foreachActive { (i, j, v) => + newArray(j * numRows + i) = v + } + newArray + } /** Converts to a breeze matrix. */ private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ - private[mllib] def apply(i: Int, j: Int): Double + def apply(i: Int, j: Int): Double /** Return the index for the (i, j)-th element in the backing array. */ private[mllib] def index(i: Int, j: Int): Int @@ -52,10 +61,13 @@ sealed trait Matrix extends Serializable { /** Get a deep copy of the matrix. */ def copy: Matrix + /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + def transpose: Matrix + /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ def multiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(false, false, 1.0, this, y, 0.0, C) + val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) + BLAS.gemm(1.0, this, y, 0.0, C) C } @@ -66,20 +78,6 @@ sealed trait Matrix extends Serializable { output } - /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ - private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(true, false, 1.0, this, y, 0.0, C) - C - } - - /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. 
*/ - private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { - val output = new DenseVector(new Array[Double](numCols)) - BLAS.gemv(true, 1.0, this, y, 0.0, output) - output - } - /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() @@ -92,6 +90,16 @@ sealed trait Matrix extends Serializable { * backing array. For example, an operation such as addition or subtraction will only be * performed on the non-zero values in a `SparseMatrix`. */ private[mllib] def update(f: Double => Double): Matrix + + /** + * Applies a function `f` to all the active elements of dense and sparse matrix. The ordering + * of the elements are not defined. + * + * @param f the function takes three parameters where the first two parameters are the row + * and column indices respectively with the type `Int`, and the final parameter is the + * corresponding value in the matrix with type `Double`. + */ + private[spark] def foreachActive(f: (Int, Int, Double) => Unit) } /** @@ -107,14 +115,36 @@ sealed trait Matrix extends Serializable { * * @param numRows number of rows * @param numCols number of columns - * @param values matrix entries in column major + * @param values matrix entries in column major if not transposed or in row major otherwise + * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in + * row major. */ -class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix { +class DenseMatrix( + val numRows: Int, + val numCols: Int, + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") - override def toArray: Array[Double] = values + /** + * Column-major dense matrix. 
+ * The entry values are stored in a single array of doubles with columns listed in sequence. + * For example, the following matrix + * {{{ + * 1.0 2.0 + * 3.0 4.0 + * 5.0 6.0 + * }}} + * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param values matrix entries in column major + */ + def this(numRows: Int, numCols: Int, values: Array[Double]) = + this(numRows, numCols, values, false) override def equals(o: Any) = o match { case m: DenseMatrix => @@ -122,13 +152,22 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) case _ => false } - private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BDM[Double](numRows, numCols, values) + } else { + val breezeMatrix = new BDM[Double](numCols, numRows, values) + breezeMatrix.t + } + } private[mllib] def apply(i: Int): Double = values(i) - private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) + override def apply(i: Int, j: Int): Double = values(index(i, j)) - private[mllib] def index(i: Int, j: Int): Int = i + numRows * j + private[mllib] def index(i: Int, j: Int): Int = { + if (!isTransposed) i + numRows * j else j + numCols * i + } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { values(index(i, j)) = v @@ -148,8 +187,41 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) this } - /** Generate a `SparseMatrix` from the given `DenseMatrix`. 
*/ - def toSparse(): SparseMatrix = { + override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + // outer loop over columns + var j = 0 + while (j < numCols) { + var i = 0 + val indStart = j * numRows + while (i < numRows) { + f(i, j, values(indStart + i)) + i += 1 + } + j += 1 + } + } else { + // outer loop over rows + var i = 0 + while (i < numRows) { + var j = 0 + val indStart = i * numCols + while (j < numCols) { + f(i, j, values(indStart + j)) + j += 1 + } + i += 1 + } + } + } + + /** + * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed + * set to false. + */ + def toSparse: SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt @@ -157,9 +229,8 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) var j = 0 while (j < numCols) { var i = 0 - val indStart = j * numRows while (i < numRows) { - val v = values(indStart + i) + val v = values(index(i, j)) if (v != 0.0) { rowIndices += i spVals += v @@ -185,8 +256,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros */ - def zeros(numRows: Int, numCols: Int): DenseMatrix = + def zeros(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + } /** * Generate a `DenseMatrix` consisting of ones. 
@@ -194,8 +268,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones */ - def ones(numRows: Int, numCols: Int): DenseMatrix = + def ones(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + } /** * Generate an Identity Matrix in `DenseMatrix` format. @@ -213,24 +290,28 @@ object DenseMatrix { } /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` gaussian random numbers. 
* @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } @@ -267,53 +348,77 @@ object DenseMatrix { * * @param numRows number of rows * @param numCols number of columns - * @param colPtrs the index corresponding to the start of a new column - * @param rowIndices the row index of the entry. They must be in strictly increasing order for each - * column - * @param values non-zero matrix entries in column major + * @param colPtrs the index corresponding to the start of a new column (if not transposed) + * @param rowIndices the row index of the entry (if not transposed). They must be in strictly + * increasing order for each column + * @param values nonzero matrix entries in column major (if not transposed) + * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered + * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, + * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ class SparseMatrix( val numRows: Int, val numCols: Int, val colPtrs: Array[Int], val rowIndices: Array[Int], - val values: Array[Double]) extends Matrix { + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + - s"number of columns + 1. 
Currently, colPointers.length: ${colPtrs.length}, " + - s"numCols: $numCols") + // The Or statement is for the case when the matrix is transposed + require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + + "column indices should be the number of columns + 1. Currently, colPointers.length: " + + s"${colPtrs.length}, numCols: $numCols") require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") - override def toArray: Array[Double] = { - val arr = new Array[Double](numRows * numCols) - var j = 0 - while (j < numCols) { - var i = colPtrs(j) - val indEnd = colPtrs(j + 1) - val offset = j * numRows - while (i < indEnd) { - val rowIndex = rowIndices(i) - arr(offset + rowIndex) = values(i) - i += 1 - } - j += 1 - } - arr + /** + * Column-major sparse matrix. + * The entry values are stored in Compressed Sparse Column (CSC) format. + * For example, the following matrix + * {{{ + * 1.0 0.0 4.0 + * 0.0 3.0 5.0 + * 2.0 0.0 6.0 + * }}} + * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`, + * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry. 
They must be in strictly increasing + * order for each column + * @param values non-zero matrix entries in column major + */ + def this( + numRows: Int, + numCols: Int, + colPtrs: Array[Int], + rowIndices: Array[Int], + values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) + } else { + val breezeMatrix = new BSM[Double](values, numCols, numRows, colPtrs, rowIndices) + breezeMatrix.t + } } - private[mllib] def toBreeze: BM[Double] = - new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) - - private[mllib] def apply(i: Int, j: Int): Double = { + override def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } private[mllib] def index(i: Int, j: Int): Int = { - Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + if (!isTransposed) { + Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + } else { + Arrays.binarySearch(rowIndices, colPtrs(i), colPtrs(i + 1), j) + } } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { @@ -322,7 +427,7 @@ class SparseMatrix( throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. Only non-zero elements in Sparse Matrices can be updated.") } else { - values(index(i, j)) = v + values(ind) = v } } @@ -341,8 +446,41 @@ class SparseMatrix( this } - /** Generate a `DenseMatrix` from the given `SparseMatrix`. 
*/ - def toDense(): DenseMatrix = { + override def transpose: SparseMatrix = + new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + var j = 0 + while (j < numCols) { + var idx = colPtrs(j) + val idxEnd = colPtrs(j + 1) + while (idx < idxEnd) { + f(rowIndices(idx), j, values(idx)) + idx += 1 + } + j += 1 + } + } else { + var i = 0 + while (i < numRows) { + var idx = colPtrs(i) + val idxEnd = colPtrs(i + 1) + while (idx < idxEnd) { + val j = rowIndices(idx) + f(i, j, values(idx)) + idx += 1 + } + i += 1 + } + } + } + + /** + * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed + * set to false. + */ + def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } } @@ -469,7 +607,7 @@ object SparseMatrix { } /** - * Generate a `SparseMatrix` consisting of i.i.d. uniform random numbers. The number of non-zero + * Generate a `SparseMatrix` consisting of `i.i.d`. uniform random numbers. The number of non-zero * elements equal the ceiling of `numRows` x `numCols` x `density` * * @param numRows number of rows of the matrix @@ -484,7 +622,7 @@ object SparseMatrix { } /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d`. gaussian random numbers. 
* @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix @@ -502,7 +640,7 @@ object SparseMatrix { * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero * `values` on the diagonal */ - def diag(vector: Vector): SparseMatrix = { + def spdiag(vector: Vector): SparseMatrix = { val n = vector.size vector match { case sVec: SparseVector => @@ -557,10 +695,9 @@ object Matrices { private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = { breeze match { case dm: BDM[Double] => - require(dm.majorStride == dm.rows, - "Do not support stride size different from the number of rows.") - new DenseMatrix(dm.rows, dm.cols, dm.data) + new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose) case sm: BSM[Double] => + // There is no isTranspose flag for sparse matrices in Breeze new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) case _ => throw new UnsupportedOperationException( @@ -569,7 +706,7 @@ object Matrices { } /** - * Generate a `DenseMatrix` consisting of zeros. + * Generate a `Matrix` consisting of zeros. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of zeros @@ -599,7 +736,7 @@ object Matrices { def speye(n: Int): Matrix = SparseMatrix.speye(n) /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator @@ -609,7 +746,7 @@ object Matrices { DenseMatrix.rand(numRows, numCols, rng) /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d.` gaussian random numbers. 
* @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix @@ -620,7 +757,7 @@ object Matrices { SparseMatrix.sprand(numRows, numCols, density, rng) /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator @@ -630,7 +767,7 @@ object Matrices { DenseMatrix.randn(numRows, numCols, rng) /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix @@ -641,8 +778,8 @@ object Matrices { SparseMatrix.sprandn(numRows, numCols, density, rng) /** - * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. - * @param vector a `Vector` tat will form the values on the diagonal of the matrix + * Generate a diagonal matrix in `Matrix` format from the supplied values. 
+ * @param vector a `Vector` that will form the values on the diagonal of the matrix * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal */ @@ -679,46 +816,28 @@ object Matrices { new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray)) } else { var startCol = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - val nCols = spMat.numCols - while (j < nCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i, j + startCol, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nCols = mat.numCols + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i, j + startCol, v) + cnt += 1 } - j += 1 - } - startCol += nCols - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startCol += nCols + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i, j + startCol, v)) } - i += 1 } - j += 1 - } - startCol += nCols - data + startCol += nCols + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } @@ -744,14 +863,12 @@ object Matrices { require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " + "don't match!") mat match { - case sparse: SparseMatrix => - hasSparse = true - case dense: DenseMatrix => + case sparse: 
SparseMatrix => hasSparse = true + case dense: DenseMatrix => // empty on purpose case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}") } numRows += mat.numRows - } if (!hasSparse) { val allValues = new Array[Double](numRows * numCols) @@ -759,61 +876,37 @@ object Matrices { matrices.foreach { mat => var j = 0 val nRows = mat.numRows - val values = mat.toArray - while (j < numCols) { - var i = 0 + mat.foreachActive { (i, j, v) => val indStart = j * numRows + startRow - val subMatStart = j * nRows - while (i < nRows) { - allValues(indStart + i) = values(subMatStart + i) - i += 1 - } - j += 1 + allValues(indStart + i) = v } startRow += nRows } new DenseMatrix(numRows, numCols, allValues) } else { var startRow = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - while (j < numCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i + startRow, j, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nRows = mat.numRows + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i + startRow, j, v) + cnt += 1 } - j += 1 - } - startRow += spMat.numRows - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startRow += nRows + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + 
dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i + startRow, j, v)) } - i += 1 } - j += 1 - } - startRow += nRows - data + startRow += nRows + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index adbd8266ed6fa..4bdcb283da09c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -26,8 +26,10 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types._ /** @@ -50,13 +52,35 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v: Vector => - util.Arrays.equals(this.toArray, v.toArray) + case v2: Vector => { + if (this.size != v2.size) return false + (this, v2) match { + case (s1: SparseVector, s2: SparseVector) => + Vectors.equals(s1.indices, s1.values, s2.indices, s2.values) + case (s1: SparseVector, d1: DenseVector) => + Vectors.equals(s1.indices, s1.values, 0 until d1.size, d1.values) + case (d1: DenseVector, s1: SparseVector) => + Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) + case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) + } + } case _ => false } } - override def hashCode(): Int = util.Arrays.hashCode(this.toArray) + override def hashCode(): Int = { + var result: Int = size + 31 + this.foreachActive { case (index, value) => + // ignore explict 0 for comparison between sparse and dense + if (value != 0) 
{ + result = 31 * result + index + // refer to {@link java.util.Arrays.equals} for hash algorithm + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } + result + } /** * Converts the instance to a breeze vector. @@ -87,9 +111,14 @@ sealed trait Vector extends Serializable { } /** + * :: DeveloperApi :: + * * User-defined type for [[Vector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.SchemaRDD]]. + * via [[org.apache.spark.sql.DataFrame]]. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ +@DeveloperApi private[spark] class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { @@ -146,6 +175,13 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" override def userClass: Class[Vector] = classOf[Vector] + + override def equals(o: Any): Boolean = { + o match { + case v: VectorUDT => true + case _ => false + } + } } /** @@ -211,7 +247,7 @@ object Vectors { } /** - * Creates a dense vector of all zeros. + * Creates a vector of all zeros. * * @param size vector size * @return a zero vector @@ -221,8 +257,7 @@ object Vectors { } /** - * Parses a string resulted from `Vector#toString` into - * an [[org.apache.spark.mllib.linalg.Vector]]. + * Parses a string resulted from [[Vector.toString]] into a [[Vector]]. */ def parse(s: String): Vector = { parseNumeric(NumericParser.parse(s)) @@ -311,7 +346,7 @@ object Vectors { math.pow(sum, 1.0 / p) } } - + /** * Returns the squared distance between two Vectors. * @param v1 first Vector. @@ -319,8 +354,9 @@ object Vectors { * @return squared distance between two Vectors. 
*/ def sqdist(v1: Vector, v2: Vector): Double = { + require(v1.size == v2.size, "vector dimension mismatch") var squaredDistance = 0.0 - (v1, v2) match { + (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => val v1Values = v1.values val v1Indices = v1.indices @@ -328,12 +364,12 @@ object Vectors { val v2Indices = v2.indices val nnzv1 = v1Indices.size val nnzv2 = v2Indices.size - + var kv1 = 0 var kv2 = 0 while (kv1 < nnzv1 || kv2 < nnzv2) { var score = 0.0 - + if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) { score = v1Values(kv1) kv1 += 1 @@ -348,18 +384,23 @@ object Vectors { squaredDistance += score * score } - case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 => + case (v1: SparseVector, v2: DenseVector) => squaredDistance = sqdist(v1, v2) - case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 => + case (v1: DenseVector, v2: SparseVector) => squaredDistance = sqdist(v2, v1) - // When a SparseVector is approximately dense, we treat it as a DenseVector - case (v1, v2) => - squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) => - val score = elems._1 - elems._2 - distance + score * score + case (DenseVector(vv1), DenseVector(vv2)) => + var kv = 0 + val sz = vv1.size + while (kv < sz) { + val score = vv1(kv) - vv2(kv) + squaredDistance += score * score + kv += 1 } + case _ => + throw new IllegalArgumentException("Do not support vector type " + v1.getClass + + " and " + v2.getClass) } squaredDistance } @@ -375,7 +416,7 @@ object Vectors { val nnzv1 = indices.size val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 - + while (kv2 < nnzv2) { var score = 0.0 if (kv2 != iv1) { @@ -392,6 +433,33 @@ object Vectors { } squaredDistance } + + /** + * Check equality between sparse/dense vectors + */ + private[mllib] def equals( + v1Indices: IndexedSeq[Int], + v1Values: Array[Double], + v2Indices: IndexedSeq[Int], + v2Values: Array[Double]): Boolean = { 
+ val v1Size = v1Values.size + val v2Size = v2Values.size + var k1 = 0 + var k2 = 0 + var allEqual = true + while (allEqual) { + while (k1 < v1Size && v1Values(k1) == 0) k1 += 1 + while (k2 < v2Size && v2Values(k2) == 0) k2 += 1 + + if (k1 >= v1Size || k2 >= v2Size) { + return k1 >= v1Size && k2 >= v2Size // check end alignment + } + allEqual = v1Indices(k1) == v2Indices(k2) && v1Values(k1) == v2Values(k2) + k1 += 1 + k2 += 1 + } + allEqual + } } /** @@ -427,6 +495,7 @@ class DenseVector(val values: Array[Double]) extends Vector { } object DenseVector { + /** Extracts the value array from a dense vector. */ def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala new file mode 100644 index 0000000000000..1d253963130f1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.linalg.distributed + +import scala.collection.mutable.ArrayBuffer + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.{Logging, Partitioner, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * A grid partitioner, which uses a regular grid to partition coordinates. + * + * @param rows Number of rows. + * @param cols Number of columns. + * @param rowsPerPart Number of rows per partition, which may be less at the bottom edge. + * @param colsPerPart Number of columns per partition, which may be less at the right edge. + */ +private[mllib] class GridPartitioner( + val rows: Int, + val cols: Int, + val rowsPerPart: Int, + val colsPerPart: Int) extends Partitioner { + + require(rows > 0) + require(cols > 0) + require(rowsPerPart > 0) + require(colsPerPart > 0) + + private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt + private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt + + override val numPartitions = rowPartitions * colPartitions + + /** + * Returns the index of the partition the input coordinate belongs to. + * + * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in + * multiplication. k is ignored in computing partitions. + * @return The index of the partition, which the coordinate belongs to. + */ + override def getPartition(key: Any): Int = { + key match { + case (i: Int, j: Int) => + getPartitionId(i, j) + case (i: Int, j: Int, _: Int) => + getPartitionId(i, j) + case _ => + throw new IllegalArgumentException(s"Unrecognized key: $key.") + } + } + + /** Partitions sub-matrices as blocks with neighboring sub-matrices. 
*/ + private def getPartitionId(i: Int, j: Int): Int = { + require(0 <= i && i < rows, s"Row index $i out of range [0, $rows).") + require(0 <= j && j < cols, s"Column index $j out of range [0, $cols).") + i / rowsPerPart + j / colsPerPart * rowPartitions + } + + override def equals(obj: Any): Boolean = { + obj match { + case r: GridPartitioner => + (this.rows == r.rows) && (this.cols == r.cols) && + (this.rowsPerPart == r.rowsPerPart) && (this.colsPerPart == r.colsPerPart) + case _ => + false + } + } +} + +private[mllib] object GridPartitioner { + + /** Creates a new [[GridPartitioner]] instance. */ + def apply(rows: Int, cols: Int, rowsPerPart: Int, colsPerPart: Int): GridPartitioner = { + new GridPartitioner(rows, cols, rowsPerPart, colsPerPart) + } + + /** Creates a new [[GridPartitioner]] instance with the input suggested number of partitions. */ + def apply(rows: Int, cols: Int, suggestedNumPartitions: Int): GridPartitioner = { + require(suggestedNumPartitions > 0) + val scale = 1.0 / math.sqrt(suggestedNumPartitions) + val rowsPerPart = math.round(math.max(scale * rows, 1.0)).toInt + val colsPerPart = math.round(math.max(scale * cols, 1.0)).toInt + new GridPartitioner(rows, cols, rowsPerPart, colsPerPart) + } +} + +/** + * :: Experimental :: + * + * Represents a distributed matrix in blocks of local matrices. + * + * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that + * form this distributed matrix. If multiple blocks with the same index exist, the + * results for operations like add and multiply will be unpredictable. + * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final + * rows are not required to have the given number of rows + * @param colsPerBlock Number of columns that make up each block. The blocks forming the final + * columns are not required to have the given number of columns + * @param nRows Number of rows of this matrix. 
If the supplied value is less than or equal to zero, + * the number of rows will be calculated when `numRows` is invoked. + * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to + * zero, the number of columns will be calculated when `numCols` is invoked. + */ +@Experimental +class BlockMatrix( + val blocks: RDD[((Int, Int), Matrix)], + val rowsPerBlock: Int, + val colsPerBlock: Int, + private var nRows: Long, + private var nCols: Long) extends DistributedMatrix with Logging { + + private type MatrixBlock = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), sub-matrix) + + /** + * Alternate constructor for BlockMatrix without the input of the number of rows and columns. + * + * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that + * form this distributed matrix. If multiple blocks with the same index exist, the + * results for operations like add and multiply will be unpredictable. + * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final + * rows are not required to have the given number of rows + * @param colsPerBlock Number of columns that make up each block. 
The blocks forming the final + * columns are not required to have the given number of columns + */ + def this( + blocks: RDD[((Int, Int), Matrix)], + rowsPerBlock: Int, + colsPerBlock: Int) = { + this(blocks, rowsPerBlock, colsPerBlock, 0L, 0L) + } + + override def numRows(): Long = { + if (nRows <= 0L) estimateDim() + nRows + } + + override def numCols(): Long = { + if (nCols <= 0L) estimateDim() + nCols + } + + val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt + val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt + + private[mllib] def createPartitioner(): GridPartitioner = + GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size) + + private lazy val blockInfo = blocks.mapValues(block => (block.numRows, block.numCols)).cache() + + /** Estimates the dimensions of the matrix. */ + private def estimateDim(): Unit = { + val (rows, cols) = blockInfo.map { case ((blockRowIndex, blockColIndex), (m, n)) => + (blockRowIndex.toLong * rowsPerBlock + m, + blockColIndex.toLong * colsPerBlock + n) + }.reduce { (x0, x1) => + (math.max(x0._1, x1._1), math.max(x0._2, x1._2)) + } + if (nRows <= 0L) nRows = rows + assert(rows <= nRows, s"The number of rows $rows is more than claimed $nRows.") + if (nCols <= 0L) nCols = cols + assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.") + } + + /** + * Validates the block matrix info against the matrix data (`blocks`) and throws an exception if + * any error is found. + */ + def validate(): Unit = { + logDebug("Validating BlockMatrix...") + // check if the matrix is larger than the claimed dimensions + estimateDim() + logDebug("BlockMatrix dimensions are okay...") + + // Check if there are multiple MatrixBlocks with the same index. + blockInfo.countByKey().foreach { case (key, cnt) => + if (cnt > 1) { + throw new SparkException(s"Found multiple MatrixBlocks with the indices $key. 
Please " + + "remove blocks with duplicate indices.") + } + } + logDebug("MatrixBlock indices are okay...") + // Check if each MatrixBlock (except edges) has the dimensions rowsPerBlock x colsPerBlock + // The first tuple is the index and the second tuple is the dimensions of the MatrixBlock + val dimensionMsg = s"dimensions different than rowsPerBlock: $rowsPerBlock, and " + + s"colsPerBlock: $colsPerBlock. Blocks on the right and bottom edges can have smaller " + + s"dimensions. You may use the repartition method to fix this issue." + blockInfo.foreach { case ((blockRowIndex, blockColIndex), (m, n)) => + if ((blockRowIndex < numRowBlocks - 1 && m != rowsPerBlock) || + (blockRowIndex == numRowBlocks - 1 && (m <= 0 || m > rowsPerBlock))) { + throw new SparkException(s"The MatrixBlock at ($blockRowIndex, $blockColIndex) has " + + dimensionMsg) + } + if ((blockColIndex < numColBlocks - 1 && n != colsPerBlock) || + (blockColIndex == numColBlocks - 1 && (n <= 0 || n > colsPerBlock))) { + throw new SparkException(s"The MatrixBlock at ($blockRowIndex, $blockColIndex) has " + + dimensionMsg) + } + } + logDebug("MatrixBlock dimensions are okay...") + logDebug("BlockMatrix is valid!") + } + + /** Caches the underlying RDD. */ + def cache(): this.type = { + blocks.cache() + this + } + + /** Persists the underlying RDD with the specified storage level. */ + def persist(storageLevel: StorageLevel): this.type = { + blocks.persist(storageLevel) + this + } + + /** Converts to CoordinateMatrix. 
*/ + def toCoordinateMatrix(): CoordinateMatrix = { + val entryRDD = blocks.flatMap { case ((blockRowIndex, blockColIndex), mat) => + val rowStart = blockRowIndex.toLong * rowsPerBlock + val colStart = blockColIndex.toLong * colsPerBlock + val entryValues = new ArrayBuffer[MatrixEntry]() + mat.foreachActive { (i, j, v) => + if (v != 0.0) entryValues.append(new MatrixEntry(rowStart + i, colStart + j, v)) + } + entryValues + } + new CoordinateMatrix(entryRDD, numRows(), numCols()) + } + + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + def toIndexedRowMatrix(): IndexedRowMatrix = { + require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " + + s"numCols: ${numCols()}") + // TODO: This implementation may be optimized + toCoordinateMatrix().toIndexedRowMatrix() + } + + /** Collect the distributed matrix on the driver as a `DenseMatrix`. */ + def toLocalMatrix(): Matrix = { + require(numRows() < Int.MaxValue, "The number of rows of this matrix should be less than " + + s"Int.MaxValue. Currently numRows: ${numRows()}") + require(numCols() < Int.MaxValue, "The number of columns of this matrix should be less than " + + s"Int.MaxValue. Currently numCols: ${numCols()}") + require(numRows() * numCols() < Int.MaxValue, "The length of the values array must be " + + s"less than Int.MaxValue. 
Currently numRows * numCols: ${numRows() * numCols()}") + val m = numRows().toInt + val n = numCols().toInt + val mem = m * n / 125000 + if (mem > 500) logWarning(s"Storing this matrix will require $mem MB of memory!") + val localBlocks = blocks.collect() + val values = new Array[Double](m * n) + localBlocks.foreach { case ((blockRowIndex, blockColIndex), submat) => + val rowOffset = blockRowIndex * rowsPerBlock + val colOffset = blockColIndex * colsPerBlock + submat.foreachActive { (i, j, v) => + val indexOffset = (j + colOffset) * m + rowOffset + i + values(indexOffset) = v + } + } + new DenseMatrix(m, n, values) + } + + /** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the + * same underlying data. Is a lazy operation. */ + def transpose: BlockMatrix = { + val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) => + ((blockColIndex, blockRowIndex), mat.transpose) + } + new BlockMatrix(transposedBlocks, colsPerBlock, rowsPerBlock, nCols, nRows) + } + + /** Collects data and assembles a local dense breeze matrix (for test only). */ + private[mllib] def toBreeze(): BDM[Double] = { + val localMat = toLocalMatrix() + new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray) + } + + /** Adds two block matrices together. The matrices must have the same size and matching + * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are + * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even + * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will + * also be a [[DenseMatrix]]. + */ + def add(other: BlockMatrix): BlockMatrix = { + require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " + + s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}") + require(numCols() == other.numCols(), "Both matrices must have the same number of columns. 
" + + s"A.numCols: ${numCols()}, B.numCols: ${other.numCols()}") + if (rowsPerBlock == other.rowsPerBlock && colsPerBlock == other.colsPerBlock) { + val addedBlocks = blocks.cogroup(other.blocks, createPartitioner()) + .map { case ((blockRowIndex, blockColIndex), (a, b)) => + if (a.size > 1 || b.size > 1) { + throw new SparkException("There are multiple MatrixBlocks with indices: " + + s"($blockRowIndex, $blockColIndex). Please remove them.") + } + if (a.isEmpty) { + new MatrixBlock((blockRowIndex, blockColIndex), b.head) + } else if (b.isEmpty) { + new MatrixBlock((blockRowIndex, blockColIndex), a.head) + } else { + val result = a.head.toBreeze + b.head.toBreeze + new MatrixBlock((blockRowIndex, blockColIndex), Matrices.fromBreeze(result)) + } + } + new BlockMatrix(addedBlocks, rowsPerBlock, colsPerBlock, numRows(), numCols()) + } else { + throw new SparkException("Cannot add matrices with different block dimensions") + } + } + + /** Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` + * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains + * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output + * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause + * some performance issues until support for multiplying two sparse matrices is added. + */ + def multiply(other: BlockMatrix): BlockMatrix = { + require(numCols() == other.numRows(), "The number of columns of A and the number of rows " + + s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. 
If you " + + "think they should be equal, try setting the dimensions of A and B explicitly while " + + "initializing them.") + if (colsPerBlock == other.rowsPerBlock) { + val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks, + math.max(blocks.partitions.length, other.blocks.partitions.length)) + // Each block of A must be multiplied with the corresponding blocks in each column of B. + // TODO: Optimize to send block to a partition once, similar to ALS + val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => + Iterator.tabulate(other.numColBlocks)(j => ((blockRowIndex, j, blockColIndex), block)) + } + // Each block of B must be multiplied with the corresponding blocks in each row of A. + val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => + Iterator.tabulate(numRowBlocks)(i => ((i, blockColIndex, blockRowIndex), block)) + } + val newBlocks: RDD[MatrixBlock] = flatA.cogroup(flatB, resultPartitioner) + .flatMap { case ((blockRowIndex, blockColIndex, _), (a, b)) => + if (a.size > 1 || b.size > 1) { + throw new SparkException("There are multiple MatrixBlocks with indices: " + + s"($blockRowIndex, $blockColIndex). Please remove them.") + } + if (a.nonEmpty && b.nonEmpty) { + val C = b.head match { + case dense: DenseMatrix => a.head.multiply(dense) + case sparse: SparseMatrix => a.head.multiply(sparse.toDense) + case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.") + } + Iterator(((blockRowIndex, blockColIndex), C.toBreeze)) + } else { + Iterator() + } + }.reduceByKey(resultPartitioner, (a, b) => a + b) + .mapValues(Matrices.fromBreeze) + // TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matrices + new BlockMatrix(newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols()) + } else { + throw new SparkException("colsPerBlock of A doesn't match rowsPerBlock of B. 
" + + s"A.colsPerBlock: $colsPerBlock, B.rowsPerBlock: ${other.rowsPerBlock}") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 06d8915f3bfa1..078d1fac44443 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -21,8 +21,7 @@ import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} /** * :: Experimental :: @@ -69,6 +68,11 @@ class CoordinateMatrix( nRows } + /** Transposes this CoordinateMatrix. */ + def transpose(): CoordinateMatrix = { + new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows()) + } + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ def toIndexedRowMatrix(): IndexedRowMatrix = { val nl = numCols() @@ -93,6 +97,46 @@ class CoordinateMatrix( toIndexedRowMatrix().toRowMatrix() } + /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + def toBlockMatrix(): BlockMatrix = { + toBlockMatrix(1024, 1024) + } + + /** + * Converts to BlockMatrix. Creates blocks of [[SparseMatrix]]. + * @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have + * a smaller value. Must be an integer value greater than 0. + * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have + * a smaller value. Must be an integer value greater than 0. 
+ * @return a [[BlockMatrix]] + */ + def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { + require(rowsPerBlock > 0, + s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock") + require(colsPerBlock > 0, + s"colsPerBlock needs to be greater than 0. colsPerBlock: $colsPerBlock") + val m = numRows() + val n = numCols() + val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt + val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt + val partitioner = GridPartitioner(numRowBlocks, numColBlocks, entries.partitions.length) + + val blocks: RDD[((Int, Int), Matrix)] = entries.map { entry => + val blockRowIndex = (entry.i / rowsPerBlock).toInt + val blockColIndex = (entry.j / colsPerBlock).toInt + + val rowId = entry.i % rowsPerBlock + val colId = entry.j % colsPerBlock + + ((blockRowIndex, blockColIndex), (rowId.toInt, colId.toInt, entry.value)) + }.groupByKey(partitioner).map { case ((blockRowIndex, blockColIndex), entry) => + val effRows = math.min(m - blockRowIndex.toLong * rowsPerBlock, rowsPerBlock).toInt + val effCols = math.min(n - blockColIndex.toLong * colsPerBlock, colsPerBlock).toInt + ((blockRowIndex, blockColIndex), SparseMatrix.fromCOO(effRows, effCols, entry)) + } + new BlockMatrix(blocks, rowsPerBlock, colsPerBlock, m, n) + } + /** Determines the size by computing the max row/column index. */ private def computeSize() { // Reduce will throw an exception if `entries` is empty. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 181f507516485..3be530fa07537 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -75,6 +75,41 @@ class IndexedRowMatrix( new RowMatrix(rows.map(_.vector), 0L, nCols) } + /** Converts to BlockMatrix. 
Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + def toBlockMatrix(): BlockMatrix = { + toBlockMatrix(1024, 1024) + } + + /** + * Converts to BlockMatrix. Creates blocks of [[SparseMatrix]]. + * @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have + * a smaller value. Must be an integer value greater than 0. + * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have + * a smaller value. Must be an integer value greater than 0. + * @return a [[BlockMatrix]] + */ + def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { + // TODO: This implementation may be optimized + toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock) + } + + /** + * Converts this matrix to a + * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. + */ + def toCoordinateMatrix(): CoordinateMatrix = { + val entries = rows.flatMap { row => + val rowIndex = row.index + row.vector match { + case SparseVector(size, indices, values) => + Iterator.tabulate(indices.size)(i => MatrixEntry(rowIndex, indices(i), values(i))) + case DenseVector(values) => + Iterator.tabulate(values.size)(i => MatrixEntry(rowIndex, i, values(i))) + } + } + new CoordinateMatrix(entries, numRows(), numCols()) + } + /** * Computes the singular value decomposition of this IndexedRowMatrix. * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index d5abba6a4b645..961111507f2c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -30,7 +30,6 @@ import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -131,8 +130,8 @@ class RowMatrix( throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols") } if (cols > 10000) { - val mem = cols * cols * 8 - logWarning(s"$cols columns will require at least $mem bytes of memory!") + val memMB = (cols.toLong * cols) / 125000 + logWarning(s"$cols columns will require at least $memMB megabytes of memory!") } } @@ -152,10 +151,10 @@ class RowMatrix( * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined * automatically based on the cost: - * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian - * matrix first and then compute its top eigenvalues and eigenvectors locally on the driver. - * This requires a single pass with O(n^2^) storage on each executor and on the driver, and - * O(n^2^ k) time on the driver. + * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute + * the Gramian matrix first and then compute its top eigenvalues and eigenvectors locally + * on the driver. 
This requires a single pass with O(n^2^) storage on each executor and + * on the driver, and O(n^2^ k) time on the driver. * - Otherwise, we compute (A' * A) * v in a distributive way and send it to ARPACK's DSAUPD to * compute (A' * A)'s top eigenvalues and eigenvectors on the driver node. This requires O(k) * passes, O(n) storage on each executor, and O(n k) storage on the driver. @@ -220,8 +219,12 @@ class RowMatrix( val computeMode = mode match { case "auto" => + if(k > 5000) { + logWarning(s"computing svd with k=$k and n=$n, please check necessity") + } + // TODO: The conditions below are not fully tested. - if (n < 100 || k > n / 2) { + if (n < 100 || (k > n / 2 && n <= 15000)) { // If n is small or k is large compared with n, we better compute the Gramian matrix first // and then compute its eigenvalues locally, instead of making multiple passes. if (k < n / 3) { @@ -246,6 +249,8 @@ class RowMatrix( val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter) case SVDMode.LocalLAPACK => + // breeze (v0.10) svd latent constraint, 7 * n * n + 4 * n < Int.MaxValue + require(n < 17515, s"$n exceeds the breeze svd capability") val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 1ca0f36c6ac34..8bfa0d2b64995 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.optimization import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} import 
org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} import org.apache.spark.mllib.util.MLUtils @@ -55,24 +55,96 @@ abstract class Gradient extends Serializable { /** * :: DeveloperApi :: - * Compute gradient and loss for a logistic loss function, as used in binary classification. - * See also the documentation for the precise formulation. + * Compute gradient and loss for a multinomial logistic loss function, as used + * in multi-class classification (it is also used in binary logistic regression). + * + * In `The Elements of Statistical Learning: Data Mining, Inference, and Prediction, 2nd Edition` + * by Trevor Hastie, Robert Tibshirani, and Jerome Friedman, which can be downloaded from + * http://statweb.stanford.edu/~tibs/ElemStatLearn/ , Eq. (4.17) on page 119 gives the formula of + * multinomial logistic regression model. A simple calculation shows that + * + * {{{ + * P(y=0|x, w) = 1 / (1 + \sum_i^{K-1} \exp(x w_i)) + * P(y=1|x, w) = exp(x w_1) / (1 + \sum_i^{K-1} \exp(x w_i)) + * ... + * P(y=K-1|x, w) = exp(x w_{K-1}) / (1 + \sum_i^{K-1} \exp(x w_i)) + * }}} + * + * for K classes multiclass classification problem. + * + * The model weights w = (w_1, w_2, ..., w_{K-1})^T becomes a matrix which has dimension of + * (K-1) * (N+1) if the intercepts are added. If the intercepts are not added, the dimension + * will be (K-1) * N. + * + * As a result, the loss of objective function for a single instance of data can be written as + * {{{ + * l(w, x) = -log P(y|x, w) = -\alpha(y) log P(y=0|x, w) - (1-\alpha(y)) log P(y|x, w) + * = log(1 + \sum_i^{K-1}\exp(x w_i)) - (1-\alpha(y)) x w_{y-1} + * = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1} + * }}} + * + * where \alpha(i) = 1 if i == 0, and + * \alpha(i) = 0 if i != 0, + * margins_i = x w_i. 
+ * + * For optimization, we have to calculate the first derivative of the loss function, and + * a simple calculation shows that + * + * {{{ + * \frac{\partial l(w, x)}{\partial w_{ij}} + * = (\exp(x w_i) / (1 + \sum_k^{K-1} \exp(x w_k)) - (1-\alpha(y))\delta_{y, i+1}) * x_j + * = multiplier_i * x_j + * }}} + * + * where \delta_{i, j} = 1 if i == j, + * \delta_{i, j} = 0 if i != j, and + * multiplier = + * \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_k)) - (1-\alpha(y))\delta_{y, i+1} + * + * If any of margins is larger than 709.78, the numerical computation of multiplier and loss + * function will suffer from arithmetic overflow. This issue occurs when there are outliers + * in data which are far away from hyperplane, and this will cause the failing of training once + * infinity / infinity is introduced. Note that this is only a concern when max(margins) > 0. + * + * Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can be + * easily rewritten into the following equivalent numerically stable formula. + * + * {{{ + * l(w, x) = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1} + * = log(\exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin)) + maxMargin + * - (1-\alpha(y)) margins_{y-1} + * = log(1 + sum) + maxMargin - (1-\alpha(y)) margins_{y-1} + * }}} + * + * where sum = \exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin) - 1. + * + * Note that each term, (margins_i - maxMargin) in \exp is smaller than zero; as a result, + * overflow will not happen with this formula. + * + * For multiplier, similar trick can be applied as the following, + * + * {{{ + * multiplier = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_k)) - (1-\alpha(y))\delta_{y, i+1} + * = \exp(margins_i - maxMargin) / (1 + sum) - (1-\alpha(y))\delta_{y, i+1} + * }}} + * + * where each term in \exp is also smaller than zero, so overflow is not a concern. 
+ * + * For the detailed mathematical derivation, see the reference at + * http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297 + * + * @param numClasses the number of possible outcomes for k classes classification problem in + * Multinomial Logistic Regression. By default, it is binary logistic regression + * so numClasses will be set to 2. */ @DeveloperApi -class LogisticGradient extends Gradient { - override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val margin = -1.0 * dot(data, weights) - val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - val gradient = data.copy - scal(gradientMultiplier, gradient) - val loss = - if (label > 0) { - // The following is equivalent to log(1 + exp(margin)) but more numerically stable. - MLUtils.log1pExp(margin) - } else { - MLUtils.log1pExp(margin) - margin - } +class LogisticGradient(numClasses: Int) extends Gradient { + def this() = this(2) + + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) (gradient, loss) } @@ -81,14 +153,104 @@ class LogisticGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val margin = -1.0 * dot(data, weights) - val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - axpy(gradientMultiplier, data, cumGradient) - if (label > 0) { - // The following is equivalent to log(1 + exp(margin)) but more numerically stable. - MLUtils.log1pExp(margin) - } else { - MLUtils.log1pExp(margin) - margin + val dataSize = data.size + + // (weights.size / dataSize + 1) is number of classes + require(weights.size % dataSize == 0 && numClasses == weights.size / dataSize + 1) + numClasses match { + case 2 => + /** + * For Binary Logistic Regression. 
+ * + * Although the loss and gradient calculation for multinomial one is more generalized, + * and multinomial one can also be used in binary case, we still implement a specialized + * binary version for performance reason. + */ + val margin = -1.0 * dot(data, weights) + val multiplier = (1.0 / (1.0 + math.exp(margin))) - label + axpy(multiplier, data, cumGradient) + if (label > 0) { + // The following is equivalent to log(1 + exp(margin)) but more numerically stable. + MLUtils.log1pExp(margin) + } else { + MLUtils.log1pExp(margin) - margin + } + case _ => + /** + * For Multinomial Logistic Regression. + */ + val weightsArray = weights match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weights.getClass}.") + } + val cumGradientArray = cumGradient match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"cumGradient only supports dense vector but got type ${cumGradient.getClass}.") + } + + // marginY is margins(label - 1) in the formula. + var marginY = 0.0 + var maxMargin = Double.NegativeInfinity + var maxMarginIndex = 0 + + val margins = Array.tabulate(numClasses - 1) { i => + var margin = 0.0 + data.foreachActive { (index, value) => + if (value != 0.0) margin += value * weightsArray((i * dataSize) + index) + } + if (i == label.toInt - 1) marginY = margin + if (margin > maxMargin) { + maxMargin = margin + maxMarginIndex = i + } + margin + } + + /** + * When maxMargin > 0, the original formula will cause overflow as we discuss + * in the previous comment. + * We address this by subtracting maxMargin from all the margins, so it's guaranteed + * that all of the new margins will be smaller than zero to prevent arithmetic overflow. 
+ */ + val sum = { + var temp = 0.0 + if (maxMargin > 0) { + for (i <- 0 until numClasses - 1) { + margins(i) -= maxMargin + if (i == maxMarginIndex) { + temp += math.exp(-maxMargin) + } else { + temp += math.exp(margins(i)) + } + } + } else { + for (i <- 0 until numClasses - 1) { + temp += math.exp(margins(i)) + } + } + temp + } + + for (i <- 0 until numClasses - 1) { + val multiplier = math.exp(margins(i)) / (sum + 1.0) - { + if (label != 0.0 && label == i + 1) 1.0 else 0.0 + } + data.foreachActive { (index, value) => + if (value != 0.0) cumGradientArray(i * dataSize + index) += multiplier * value + } + } + + val loss = if (label > 0.0) math.log1p(sum) - marginY else math.log1p(sum) + + if (maxMargin > 0) { + loss + maxMargin + } else { + loss + } } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 0857877951c82..4b7d0589c973b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -25,7 +25,6 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} -import org.apache.spark.mllib.rdd.RDDFunctions._ /** * Class used to solve an optimization problem using Gradient Descent. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index d16d0daf08565..d5e4f4ccbff10 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -26,7 +26,6 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.axpy -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala index fef062e02b6ec..ccd93b318bc23 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala @@ -19,13 +19,11 @@ package org.apache.spark.mllib.optimization import org.jblas.{DoubleMatrix, SimpleBlas} -import org.apache.spark.annotation.DeveloperApi - /** * Object used to solve nonnegative least squares problems using a modified * projected gradient method. */ -private[mllib] object NNLS { +private[spark] object NNLS { class Workspace(val n: Int) { val scratch = new DoubleMatrix(n, 1) val grad = new DoubleMatrix(n, 1) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 955c593a085d5..8341bb86afd71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -29,13 +29,13 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. 
+ * Generator methods for creating RDDs comprised of `i.i.d.` samples from some distribution. */ @Experimental object RandomRDDs { /** - * Generates an RDD comprised of i.i.d. samples from the uniform distribution `U(0.0, 1.0)`. + * Generates an RDD comprised of `i.i.d.` samples from the uniform distribution `U(0.0, 1.0)`. * * To transform the distribution in the generated RDD from `U(0.0, 1.0)` to `U(a, b)`, use * `RandomRDDs.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. @@ -44,7 +44,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + * @return RDD[Double] comprised of `i.i.d.` samples ~ `U(0.0, 1.0)`. */ def uniformRDD( sc: SparkContext, @@ -81,7 +81,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * Generates an RDD comprised of `i.i.d.` samples from the standard normal distribution. * * To transform the distribution in the generated RDD from standard normal to some other normal * `N(mean, sigma^2^)`, use `RandomRDDs.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. @@ -90,7 +90,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + * @return RDD[Double] comprised of `i.i.d.` samples ~ N(0.0, 1.0). */ def normalRDD( sc: SparkContext, @@ -127,14 +127,15 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of `i.i.d.` samples from the Poisson distribution with the input + * mean. * * @param sc SparkContext used to create the RDD. 
* @param mean Mean, or lambda, for the Poisson distribution. * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ def poissonRDD( sc: SparkContext, @@ -177,7 +178,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the exponential distribution with + * Generates an RDD comprised of `i.i.d.` samples from the exponential distribution with * the input mean. * * @param sc SparkContext used to create the RDD. @@ -185,7 +186,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Exp(mean). */ def exponentialRDD( sc: SparkContext, @@ -228,7 +229,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the gamma distribution with the input + * Generates an RDD comprised of `i.i.d.` samples from the gamma distribution with the input * shape and scale. * * @param sc SparkContext used to create the RDD. @@ -237,7 +238,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples from the gamma distribution. */ def gammaRDD( sc: SparkContext, @@ -287,7 +288,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. 
samples from the log normal distribution with the input + * Generates an RDD comprised of `i.i.d.` samples from the log normal distribution with the input * mean and standard deviation * * @param sc SparkContext used to create the RDD. @@ -296,7 +297,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples from the log normal distribution. */ def logNormalRDD( sc: SparkContext, @@ -348,14 +349,14 @@ object RandomRDDs { /** * :: DeveloperApi :: - * Generates an RDD comprised of i.i.d. samples produced by the input RandomDataGenerator. + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. * @param generator RandomDataGenerator used to populate the RDD. * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples produced by generator. + * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi def randomRDD[T: ClassTag]( @@ -370,7 +371,7 @@ object RandomRDDs { // TODO Generate RDD[Vector] from multivariate distributions. /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * uniform distribution on `U(0.0, 1.0)`. * * @param sc SparkContext used to create the RDD. @@ -424,7 +425,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. 
samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * standard normal distribution. 
* * @param sc SparkContext used to create the RDD. @@ -432,7 +433,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ `N(0.0, 1.0)`. */ def normalVectorRDD( sc: SparkContext, @@ -478,7 +479,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from a + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from a * log normal distribution. * * @param sc SparkContext used to create the RDD. @@ -488,7 +489,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples. + * @return RDD[Vector] with vectors containing `i.i.d.` samples. */ def logNormalVectorRDD( sc: SparkContext, @@ -544,7 +545,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -553,7 +554,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Pois(mean). 
*/ def poissonVectorRDD( sc: SparkContext, @@ -603,7 +604,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * exponential distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -612,7 +613,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Exp(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ def exponentialVectorRDD( sc: SparkContext, @@ -665,7 +666,7 @@ object RandomRDDs { /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * gamma distribution with the input shape and scale. * * @param sc SparkContext used to create the RDD. @@ -675,7 +676,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Exp(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples from the gamma distribution. */ def gammaVectorRDD( sc: SparkContext, @@ -731,7 +732,7 @@ object RandomRDDs { /** * :: DeveloperApi :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples produced by the * input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. @@ -740,7 +741,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. 
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. + * @return RDD[Vector] with vectors containing `i.i.d.` samples produced by generator. */ @DeveloperApi def randomVectorRDD(sc: SparkContext, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 57c0768084e41..78172843be56e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -21,10 +21,7 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.HashPartitioner -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. @@ -53,63 +50,25 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { * Reduces the elements of this RDD in a multi-level tree pattern. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#reduce]] + * @see [[org.apache.spark.rdd.RDD#treeReduce]] + * @deprecated Use [[org.apache.spark.rdd.RDD#treeReduce]] instead. 
*/ - def treeReduce(f: (T, T) => T, depth: Int = 2): T = { - require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - val cleanF = self.context.clean(f) - val reducePartition: Iterator[T] => Option[T] = iter => { - if (iter.hasNext) { - Some(iter.reduceLeft(cleanF)) - } else { - None - } - } - val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it))) - val op: (Option[T], Option[T]) => Option[T] = (c, x) => { - if (c.isDefined && x.isDefined) { - Some(cleanF(c.get, x.get)) - } else if (c.isDefined) { - c - } else if (x.isDefined) { - x - } else { - None - } - } - RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth) - .getOrElse(throw new UnsupportedOperationException("empty collection")) - } + @deprecated("Use RDD.treeReduce instead.", "1.3.0") + def treeReduce(f: (T, T) => T, depth: Int = 2): T = self.treeReduce(f, depth) /** * Aggregates the elements of this RDD in a multi-level tree pattern. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#aggregate]] + * @see [[org.apache.spark.rdd.RDD#treeAggregate]] + * @deprecated Use [[org.apache.spark.rdd.RDD#treeAggregate]] instead. 
*/ + @deprecated("Use RDD.treeAggregate instead.", "1.3.0") def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, combOp: (U, U) => U, depth: Int = 2): U = { - require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (self.partitions.size == 0) { - return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance()) - } - val cleanSeqOp = self.context.clean(seqOp) - val cleanCombOp = self.context.clean(combOp) - val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size - val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) - // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { - numPartitions /= scale - val curNumPartitions = numPartitions - partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => - iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values - } - partiallyAggregated.reduce(cleanCombOp) + self.treeAggregate(zeroValue)(seqOp, combOp, depth) } } @@ -117,5 +76,5 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. 
*/ - implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd) + implicit def fromRDD[T: ClassTag](rdd: RDD[T]): RDDFunctions[T] = new RDDFunctions[T](rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index bee951a2e5e26..caacab943030b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -17,52 +17,16 @@ package org.apache.spark.mllib.recommendation -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.math.{abs, sqrt} -import scala.util.{Random, Sorting} -import scala.util.hashing.byteswap32 - -import org.jblas.{DoubleMatrix, SimpleBlas, Solve} - -import org.apache.spark.{HashPartitioner, Logging, Partitioner} -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.mllib.optimization.NNLS +import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils - -/** - * Out-link information for a user or product block. This includes the original user/product IDs - * of the elements within this block, and the list of destination blocks that each user or - * product will need to send its feature vector to. - */ -private[recommendation] -case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[mutable.BitSet]) - - -/** - * In-link information for a user (or product) block. 
This includes the original user/product IDs - * of the elements within this block, as well as an array of indices and ratings that specify - * which user in the block will be rated by which products from each product block (or vice-versa). - * Specifically, if this InLinkBlock is for users, ratingsForBlock(b)(i) will contain two arrays, - * indices and ratings, for the i'th product that will be sent to us by product block b (call this - * P). These arrays represent the users that product P had ratings for (by their index in this - * block), as well as the corresponding rating for each one. We can thus use this information when - * we get product block b's message to update the corresponding users. - */ -private[recommendation] case class InLinkBlock( - elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]]) - /** - * :: Experimental :: * A more compact class to represent a rating than Tuple3[Int, Int, Double]. */ -@Experimental case class Rating(user: Int, product: Int, rating: Double) /** @@ -90,7 +54,7 @@ case class Rating(user: Int, product: Int, rating: Double) * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if - * r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of * indicated user * preferences rather than explicit ratings given to items. */ @@ -169,10 +133,8 @@ class ALS private ( } /** - * :: Experimental :: * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. 
*/ - @Experimental def setAlpha(alpha: Double): this.type = { this.alpha = alpha this @@ -201,6 +163,8 @@ class ALS private ( */ @DeveloperApi def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { + require(storageLevel != StorageLevel.NONE, + "ALS is not designed to run without persisting intermediate RDDs.") this.intermediateRDDStorageLevel = storageLevel this } @@ -236,431 +200,39 @@ class ALS private ( this.numProductBlocks } - val userPartitioner = new ALSPartitioner(numUserBlocks) - val productPartitioner = new ALSPartitioner(numProductBlocks) - - val ratingsByUserBlock = ratings.map { rating => - (userPartitioner.getPartition(rating.user), rating) - } - val ratingsByProductBlock = ratings.map { rating => - (productPartitioner.getPartition(rating.product), - Rating(rating.product, rating.user, rating.rating)) - } - - val (userInLinks, userOutLinks) = - makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, productPartitioner) - val (productInLinks, productOutLinks) = - makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, userPartitioner) - userInLinks.setName("userInLinks") - userOutLinks.setName("userOutLinks") - productInLinks.setName("productInLinks") - productOutLinks.setName("productOutLinks") - - // Initialize user and product factors randomly, but use a deterministic seed for each - // partition so that fault recovery works - val seedGen = new Random(seed) - val seed1 = seedGen.nextInt() - val seed2 = seedGen.nextInt() - var users = userOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(byteswap32(seed1 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } - } - var products = productOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(byteswap32(seed2 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } - } - - if (implicitPrefs) { - for (iter <- 1 to iterations) { - // 
perform ALS update - logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) - // Persist users because it will be called twice. - users.setName(s"users-$iter").persist() - val YtY = Some(sc.broadcast(computeYtY(users))) - val previousProducts = products - products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, - rank, lambda, alpha, YtY) - previousProducts.unpersist() - logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) - if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { - products.checkpoint() - } - products.setName(s"products-$iter").persist() - val XtX = Some(sc.broadcast(computeYtY(products))) - val previousUsers = users - users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, - rank, lambda, alpha, XtX) - previousUsers.unpersist() - } - } else { - for (iter <- 1 to iterations) { - // perform ALS update - logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) - products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, - rank, lambda, alpha, YtY = None) - if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { - products.checkpoint() - } - products.setName(s"products-$iter") - logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) - users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, - rank, lambda, alpha, YtY = None) - users.setName(s"users-$iter") - } + val (floatUserFactors, floatProdFactors) = NewALS.train[Int]( + ratings = ratings.map(r => NewALS.Rating(r.user, r.product, r.rating.toFloat)), + rank = rank, + numUserBlocks = numUserBlocks, + numItemBlocks = numProductBlocks, + maxIter = iterations, + regParam = lambda, + implicitPrefs = implicitPrefs, + alpha = alpha, + nonnegative = nonnegative, + intermediateRDDStorageLevel = intermediateRDDStorageLevel, + finalRDDStorageLevel = StorageLevel.NONE, + seed = seed) + + val userFactors = floatUserFactors + 
.mapValues(_.map(_.toDouble)) + .setName("users") + .persist(finalRDDStorageLevel) + val prodFactors = floatProdFactors + .mapValues(_.map(_.toDouble)) + .setName("products") + .persist(finalRDDStorageLevel) + if (finalRDDStorageLevel != StorageLevel.NONE) { + userFactors.count() + prodFactors.count() } - - // The last `products` will be used twice. One to generate the last `users` and the other to - // generate `productsOut`. So we cache it for better performance. - products.setName("products").persist() - - // Flatten and cache the two final RDDs to un-block them - val usersOut = unblockFactors(users, userOutLinks) - val productsOut = unblockFactors(products, productOutLinks) - - usersOut.setName("usersOut").persist(finalRDDStorageLevel) - productsOut.setName("productsOut").persist(finalRDDStorageLevel) - - // Materialize usersOut and productsOut. - usersOut.count() - productsOut.count() - - products.unpersist() - - // Clean up. - userInLinks.unpersist() - userOutLinks.unpersist() - productInLinks.unpersist() - productOutLinks.unpersist() - - new MatrixFactorizationModel(rank, usersOut, productsOut) + new MatrixFactorizationModel(rank, userFactors, prodFactors) } /** * Java-friendly version of [[ALS.run]]. */ def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd) - - /** - * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors - * for each user (or product), in a distributed fashion. 
- * - * @param factors the (block-distributed) user or product factor vectors - * @return YtY - whose value is only used in the implicit preference model - */ - private def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = { - val n = rank * (rank + 1) / 2 - val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => { - Y.foreach(y => dspr(1.0, wrapDoubleArray(y), L)) - L - }, combOp = (L1, L2) => { - L1.addi(L2) - }) - val YtY = new DoubleMatrix(rank, rank) - fillFullMatrix(LYtY, YtY) - YtY - } - - /** - * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. - * - * @param L the lower triangular part of the matrix packed in an array (row major) - */ - private def dspr(alpha: Double, x: DoubleMatrix, L: DoubleMatrix) = { - val n = x.length - var i = 0 - var j = 0 - var idx = 0 - var axi = 0.0 - val xd = x.data - val Ld = L.data - while (i < n) { - axi = alpha * xd(i) - j = 0 - while (j <= i) { - Ld(idx) += axi * xd(j) - j += 1 - idx += 1 - } - i += 1 - } - } - - /** - * Wrap a double array in a DoubleMatrix without creating garbage. - * This is a temporary fix for jblas 1.2.3; it should be safe to move back to the - * DoubleMatrix(double[]) constructor come jblas 1.2.4. - */ - private def wrapDoubleArray(v: Array[Double]): DoubleMatrix = { - new DoubleMatrix(v.length, 1, v: _*) - } - - /** - * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs - */ - private def unblockFactors( - blockedFactors: RDD[(Int, Array[Array[Double]])], - outLinks: RDD[(Int, OutLinkBlock)]): RDD[(Int, Array[Double])] = { - blockedFactors.join(outLinks).flatMap { case (b, (factors, outLinkBlock)) => - for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) - } - } - - /** - * Make the out-links table for a block of the users (or products) dataset given the list of - * (user, product, rating) values for the users in that block (or the opposite for products). 
- */ - private def makeOutLinkBlock(numProductBlocks: Int, ratings: Array[Rating], - productPartitioner: Partitioner): OutLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val numUsers = userIds.length - val userIdToPos = userIds.zipWithIndex.toMap - val shouldSend = Array.fill(numUsers)(new mutable.BitSet(numProductBlocks)) - for (r <- ratings) { - shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true - } - OutLinkBlock(userIds, shouldSend) - } - - /** - * Make the in-links table for a block of the users (or products) dataset given a list of - * (user, product, rating) values for the users in that block (or the opposite for products). - */ - private def makeInLinkBlock(numProductBlocks: Int, ratings: Array[Rating], - productPartitioner: Partitioner): InLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val userIdToPos = userIds.zipWithIndex.toMap - // Split out our ratings by product block - val blockRatings = Array.fill(numProductBlocks)(new ArrayBuffer[Rating]) - for (r <- ratings) { - blockRatings(productPartitioner.getPartition(r.product)) += r - } - val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numProductBlocks) - for (productBlock <- 0 until numProductBlocks) { - // Create an array of (product, Seq(Rating)) ratings - val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray - // Sort them by product ID - val ordering = new Ordering[(Int, ArrayBuffer[Rating])] { - def compare(a: (Int, ArrayBuffer[Rating]), b: (Int, ArrayBuffer[Rating])): Int = - a._1 - b._1 - } - Sorting.quickSort(groupedRatings)(ordering) - // Translate the user IDs to indices based on userIdToPos - ratingsForBlock(productBlock) = groupedRatings.map { case (p, rs) => - (rs.view.map(r => userIdToPos(r.user)).toArray, rs.view.map(_.rating).toArray) - } - } - InLinkBlock(userIds, ratingsForBlock) - } - - /** - * Make RDDs of InLinkBlocks and OutLinkBlocks given an RDD of (blockId, (u, p, r)) 
values for - * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid - * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. - */ - private def makeLinkRDDs( - numUserBlocks: Int, - numProductBlocks: Int, - ratingsByUserBlock: RDD[(Int, Rating)], - productPartitioner: Partitioner): (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = { - val grouped = ratingsByUserBlock.partitionBy(new HashPartitioner(numUserBlocks)) - val links = grouped.mapPartitionsWithIndex((blockId, elements) => { - val ratings = elements.map(_._2).toArray - val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner) - val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner) - Iterator.single((blockId, (inLinkBlock, outLinkBlock))) - }, preservesPartitioning = true) - val inLinks = links.mapValues(_._1) - val outLinks = links.mapValues(_._2) - inLinks.persist(intermediateRDDStorageLevel) - outLinks.persist(intermediateRDDStorageLevel) - (inLinks, outLinks) - } - - /** - * Make a random factor vector with the given random. - */ - private def randomFactor(rank: Int, rand: Random): Array[Double] = { - // Choose a unit vector uniformly at random from the unit sphere, but from the - // "first quadrant" where all elements are nonnegative. This can be done by choosing - // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. - // This appears to create factorizations that have a slightly better reconstruction - // (<1%) compared picking elements uniformly at random in [0,1]. - val factor = Array.fill(rank)(abs(rand.nextGaussian())) - val norm = sqrt(factor.map(x => x * x).sum) - factor.map(x => x / norm) - } - - /** - * Compute the user feature vectors given the current products (or vice-versa). 
This first joins - * the products with their out-links to generate a set of messages to each destination block - * (specifically, the features for the products that user block cares about), then groups these - * by destination and joins them with the in-link info to figure out how to update each user. - * It returns an RDD of new feature vectors for each user block. - */ - private def updateFeatures( - numUserBlocks: Int, - products: RDD[(Int, Array[Array[Double]])], - productOutLinks: RDD[(Int, OutLinkBlock)], - userInLinks: RDD[(Int, InLinkBlock)], - rank: Int, - lambda: Double, - alpha: Double, - YtY: Option[Broadcast[DoubleMatrix]]): RDD[(Int, Array[Array[Double]])] = { - productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) => - val toSend = Array.fill(numUserBlocks)(new ArrayBuffer[Array[Double]]) - for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numUserBlocks) { - if (outLinkBlock.shouldSend(p)(userBlock)) { - toSend(userBlock) += factors(p) - } - } - toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } - }.groupByKey(new HashPartitioner(numUserBlocks)) - .join(userInLinks) - .mapValues{ case (messages, inLinkBlock) => - updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY) - } - } - - /** - * Compute the new feature vectors for a block of the users matrix given the list of factors - * it received from each product and its InLinkBlock. 
- */ - private def updateBlock(messages: Iterable[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, - rank: Int, lambda: Double, alpha: Double, YtY: Option[Broadcast[DoubleMatrix]]) - : Array[Array[Double]] = - { - // Sort the incoming block factor messages by block ID and make them an array - val blockFactors = messages.toSeq.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] - val numProductBlocks = blockFactors.length - val numUsers = inLinkBlock.elementIds.length - - // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since - // the matrices are symmetric - val triangleSize = rank * (rank + 1) / 2 - val userXtX = Array.fill(numUsers)(DoubleMatrix.zeros(triangleSize)) - val userXy = Array.fill(numUsers)(DoubleMatrix.zeros(rank)) - - // Some temp variables to avoid memory allocation - val tempXtX = DoubleMatrix.zeros(triangleSize) - val fullXtX = DoubleMatrix.zeros(rank, rank) - - // Count the number of ratings each user gives to provide user-specific regularization - val numRatings = Array.fill(numUsers)(0) - - // Compute the XtX and Xy values for each user by adding products it rated in each product - // block - for (productBlock <- 0 until numProductBlocks) { - var p = 0 - while (p < blockFactors(productBlock).length) { - val x = wrapDoubleArray(blockFactors(productBlock)(p)) - tempXtX.fill(0.0) - dspr(1.0, x, tempXtX) - val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p) - if (implicitPrefs) { - var i = 0 - while (i < us.length) { - numRatings(us(i)) += 1 - // Extension to the original paper to handle rs(i) < 0. confidence is a function - // of |rs(i)| instead so that it is never negative: - val confidence = 1 + alpha * abs(rs(i)) - SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i))) - // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i) - // means we try to reconstruct 0. 
We add terms only where P = 1, so, term below - // is now only added for rs(i) > 0: - if (rs(i) > 0) { - SimpleBlas.axpy(confidence, x, userXy(us(i))) - } - i += 1 - } - } else { - var i = 0 - while (i < us.length) { - numRatings(us(i)) += 1 - userXtX(us(i)).addi(tempXtX) - SimpleBlas.axpy(rs(i), x, userXy(us(i))) - i += 1 - } - } - p += 1 - } - } - - val ws = if (nonnegative) NNLS.createWorkspace(rank) else null - - // Solve the least-squares problem for each user and return the new feature vectors - Array.range(0, numUsers).map { index => - // Compute the full XtX matrix from the lower-triangular part we got above - fillFullMatrix(userXtX(index), fullXtX) - // Add regularization - val regParam = numRatings(index) * lambda - var i = 0 - while (i < rank) { - fullXtX.data(i * rank + i) += regParam - i += 1 - } - // Solve the resulting matrix, which is symmetric and positive-definite - if (implicitPrefs) { - solveLeastSquares(fullXtX.addi(YtY.get.value), userXy(index), ws) - } else { - solveLeastSquares(fullXtX, userXy(index), ws) - } - } - } - - /** - * Given A^T A and A^T b, find the x minimising ||Ax - b||_2, possibly subject - * to nonnegativity constraints if `nonnegative` is true. - */ - def solveLeastSquares(ata: DoubleMatrix, atb: DoubleMatrix, - ws: NNLS.Workspace): Array[Double] = { - if (!nonnegative) { - Solve.solvePositive(ata, atb).data - } else { - NNLS.solve(ata, atb, ws) - } - } - - /** - * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square - * matrix that it represents, storing it into destMatrix. - */ - private def fillFullMatrix(triangularMatrix: DoubleMatrix, destMatrix: DoubleMatrix) { - val rank = destMatrix.rows - var i = 0 - var pos = 0 - while (i < rank) { - var j = 0 - while (j <= i) { - destMatrix.data(i*rank + j) = triangularMatrix.data(pos) - destMatrix.data(j*rank + i) = triangularMatrix.data(pos) - pos += 1 - j += 1 - } - i += 1 - } - } -} - -/** - * Partitioner for ALS. 
- */ -private[recommendation] class ALSPartitioner(override val numPartitions: Int) extends Partitioner { - override def getPartition(key: Any): Int = { - Utils.nonNegativeMod(byteswap32(key.asInstanceOf[Int]), numPartitions) - } - - override def equals(obj: Any): Boolean = { - obj match { - case p: ALSPartitioner => - this.numPartitions == p.numPartitions - case _ => - false - } - } } /** @@ -834,120 +406,4 @@ object ALS { : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) } - - /** - * :: DeveloperApi :: - * Statistics of a block in ALS computation. - * - * @param category type of this block, "user" or "product" - * @param index index of this block - * @param count number of users or products inside this block, the same as the number of - * least-squares problems to solve on this block in each iteration - * @param numRatings total number of ratings inside this block, the same as the number of outer - * products we need to make on this block in each iteration - * @param numInLinks total number of incoming links, the same as the number of vectors to retrieve - * before each iteration - * @param numOutLinks total number of outgoing links, the same as the number of vectors to send - * for the next iteration - */ - @DeveloperApi - case class BlockStats( - category: String, - index: Int, - count: Long, - numRatings: Long, - numInLinks: Long, - numOutLinks: Long) - - /** - * :: DeveloperApi :: - * Given an RDD of ratings, number of user blocks, and number of product blocks, computes the - * statistics of each block in ALS computation. This is useful for estimating cost and diagnosing - * load balance. 
- * - * @param ratings an RDD of ratings - * @param numUserBlocks number of user blocks - * @param numProductBlocks number of product blocks - * @return statistics of user blocks and product blocks - */ - @DeveloperApi - def analyzeBlocks( - ratings: RDD[Rating], - numUserBlocks: Int, - numProductBlocks: Int): Array[BlockStats] = { - - val userPartitioner = new ALSPartitioner(numUserBlocks) - val productPartitioner = new ALSPartitioner(numProductBlocks) - - val ratingsByUserBlock = ratings.map { rating => - (userPartitioner.getPartition(rating.user), rating) - } - val ratingsByProductBlock = ratings.map { rating => - (productPartitioner.getPartition(rating.product), - Rating(rating.product, rating.user, rating.rating)) - } - - val als = new ALS() - val (userIn, userOut) = - als.makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, userPartitioner) - val (prodIn, prodOut) = - als.makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, productPartitioner) - - def sendGrid(outLinks: RDD[(Int, OutLinkBlock)]): Map[(Int, Int), Long] = { - outLinks.map { x => - val grid = new mutable.HashMap[(Int, Int), Long]() - val uPartition = x._1 - x._2.shouldSend.foreach { ss => - ss.foreach { pPartition => - val pair = (uPartition, pPartition) - grid.put(pair, grid.getOrElse(pair, 0L) + 1L) - } - } - grid - }.reduce { (grid1, grid2) => - grid2.foreach { x => - grid1.put(x._1, grid1.getOrElse(x._1, 0L) + x._2) - } - grid1 - }.toMap - } - - val userSendGrid = sendGrid(userOut) - val prodSendGrid = sendGrid(prodOut) - - val userInbound = new Array[Long](numUserBlocks) - val prodInbound = new Array[Long](numProductBlocks) - val userOutbound = new Array[Long](numUserBlocks) - val prodOutbound = new Array[Long](numProductBlocks) - - for (u <- 0 until numUserBlocks; p <- 0 until numProductBlocks) { - userOutbound(u) += userSendGrid.getOrElse((u, p), 0L) - prodInbound(p) += userSendGrid.getOrElse((u, p), 0L) - userInbound(u) += prodSendGrid.getOrElse((p, u), 0L) - 
prodOutbound(p) += prodSendGrid.getOrElse((p, u), 0L) - } - - val userCounts = userOut.mapValues(x => x.elementIds.length).collectAsMap() - val prodCounts = prodOut.mapValues(x => x.elementIds.length).collectAsMap() - - val userRatings = countRatings(userIn) - val prodRatings = countRatings(prodIn) - - val userStats = Array.tabulate(numUserBlocks)( - u => BlockStats("user", u, userCounts(u), userRatings(u), userInbound(u), userOutbound(u))) - val productStatus = Array.tabulate(numProductBlocks)( - p => BlockStats("product", p, prodCounts(p), prodRatings(p), prodInbound(p), prodOutbound(p))) - - (userStats ++ productStatus).toArray - } - - private def countRatings(inLinks: RDD[(Int, InLinkBlock)]): Map[Int, Long] = { - inLinks.mapValues { ilb => - var numRatings = 0L - ilb.ratingsForBlock.foreach { ar => - ar.foreach { p => numRatings += p._1.length } - } - numRatings - }.collectAsMap().toMap - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ed2f8b41bcae5..c399496568bfb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,13 +17,20 @@ package org.apache.spark.mllib.recommendation +import java.io.IOException import java.lang.{Integer => JavaInteger} +import org.apache.hadoop.fs.Path import org.jblas.DoubleMatrix +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.storage.StorageLevel /** @@ -41,7 +48,8 @@ import 
org.apache.spark.storage.StorageLevel class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { + val productFeatures: RDD[(Int, Array[Double])]) + extends Saveable with Serializable with Logging { require(rank > 0) validateFeatures("User", userFeatures) @@ -125,6 +133,12 @@ class MatrixFactorizationModel( recommend(productFeatures.lookup(product).head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + protected override val formatVersion: String = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + MatrixFactorizationModel.SaveLoadV1_0.save(this, path) + } + private def recommend( recommendToFeatures: Array[Double], recommendableFeatures: RDD[(Int, Array[Double])], @@ -136,3 +150,71 @@ class MatrixFactorizationModel( scored.top(num)(Ordering.by(_._2)) } } + +object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { + + import org.apache.spark.mllib.util.Loader._ + + override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, formatVersion) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case _ => + throw new IOException("MatrixFactorizationModel.load did not recognize model with" + + s"(class: $loadedClassName, version: $formatVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private[recommendation] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[recommendation] + val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel" + + /** + * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and + * product features are saved under `data/products`. 
+ */ + def save(model: MatrixFactorizationModel, path: String): Unit = { + val sc = model.userFeatures.sparkContext + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path)) + model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path)) + } + + def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + val (className, formatVersion, metadata) = loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rank = (metadata \ "rank").extract[Int] + val userFeatures = sqlContext.parquetFile(userPath(path)) + .map { case Row(id: Int, features: Seq[Double]) => + (id, features.toArray) + } + val productFeatures = sqlContext.parquetFile(productPath(path)) + .map { case Row(id: Int, features: Seq[Double]) => + (id, features.toArray) + } + new MatrixFactorizationModel(rank, userFeatures, productFeatures) + } + + private def userPath(path: String): String = { + new Path(dataPath(path), "user").toUri.toString + } + + private def productPath(path: String): String = { + new Path(dataPath(path), "product").toUri.toString + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 0287f04e2c777..7c66e8cdebdbe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -98,6 +98,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] 
protected var validateData: Boolean = true + /** + * In `GeneralizedLinearModel`, only single linear predictor is allowed for both weights + * and intercept. However, for multinomial logistic regression, with K possible outcomes, + * we are training K-1 independent binary logistic regression models which requires K-1 sets + * of linear predictor. + * + * As a result, the workaround here is if more than two sets of linear predictors are needed, + * we construct bigger `weights` vector which can hold both weights and intercepts. + * If the intercepts are added, the dimension of `weights` will be + * (numOfLinearPredictor) * (numFeatures + 1) . If the intercepts are not added, + * the dimension of `weights` will be (numOfLinearPredictor) * numFeatures. + * + * Thus, the intercepts will be encapsulated into weights, and we leave the value of intercept + * in GeneralizedLinearModel as zero. + */ + protected var numOfLinearPredictor: Int = 1 + /** * Whether to perform feature scaling before model training to reduce the condition numbers * which can significantly help the optimizer converging faster. The scaling correction will be @@ -106,6 +123,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ private var useFeatureScaling = false + /** + * The dimension of training features. + */ + protected var numFeatures: Int = -1 + /** * Set if the algorithm should use feature scaling to improve the convergence during optimization. */ @@ -141,8 +163,30 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * RDD of LabeledPoint entries. */ def run(input: RDD[LabeledPoint]): M = { - val numFeatures: Int = input.first().features.size - val initialWeights = Vectors.dense(new Array[Double](numFeatures)) + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } + + /** + * When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights, + * so the `weights` will include the intercepts. 
When `numOfLinearPredictor == 1`, + * the intercept will be stored as separated value in `GeneralizedLinearModel`. + * This will result in different behaviors since when `numOfLinearPredictor == 1`, + * users have no way to set the initial intercept, while in the other case, users + * can set the intercepts as part of weights. + * + * TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always + * have the intercept as part of weights to have consistent design. + */ + val initialWeights = { + if (numOfLinearPredictor == 1) { + Vectors.dense(new Array[Double](numFeatures)) + } else if (addIntercept) { + Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor)) + } else { + Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor)) + } + } run(input, initialWeights) } @@ -162,7 +206,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } - /** + /* * Scaling columns to unit variance as a heuristic to reduce the condition number: * * During the optimization process, the convergence (rate) depends on the condition number of @@ -182,42 +226,53 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * Currently, it's only enabled in LogisticRegressionWithLBFGS */ val scaler = if (useFeatureScaling) { - (new StandardScaler).fit(input.map(x => x.features)) + new StandardScaler(withStd = true, withMean = false).fit(input.map(_.features)) } else { null } // Prepend an extra variable consisting of all 1.0's for the intercept. 
- val data = if (addIntercept) { - if(useFeatureScaling) { - input.map(labeledPoint => - (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features)))) - } else { - input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) - } - } else { - if (useFeatureScaling) { - input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features))) + // TODO: Apply feature scaling to the weight vector instead of input data. + val data = + if (addIntercept) { + if (useFeatureScaling) { + input.map(lp => (lp.label, appendBias(scaler.transform(lp.features)))).cache() + } else { + input.map(lp => (lp.label, appendBias(lp.features))).cache() + } } else { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + if (useFeatureScaling) { + input.map(lp => (lp.label, scaler.transform(lp.features))).cache() + } else { + input.map(lp => (lp.label, lp.features)) + } } - } - val initialWeightsWithIntercept = if (addIntercept) { + /** + * TODO: For better convergence, in logistic regression, the intercepts should be computed + * from the prior probability distribution of the outcomes; for linear regression, + * the intercept should be set as the average of response. + */ + val initialWeightsWithIntercept = if (addIntercept && numOfLinearPredictor == 1) { appendBias(initialWeights) } else { + /** If `numOfLinearPredictor > 1`, initialWeights already contains intercepts. 
*/ initialWeights } val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) - val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0 - var weights = - if (addIntercept) { - Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) - } else { - weightsWithIntercept - } + val intercept = if (addIntercept && numOfLinearPredictor == 1) { + weightsWithIntercept(weightsWithIntercept.size - 1) + } else { + 0.0 + } + + var weights = if (addIntercept && numOfLinearPredictor == 1) { + Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) + } else { + weightsWithIntercept + } /** * The weights and intercept are trained in the scaled space; we're converting them back to @@ -228,7 +283,29 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * is the coefficient in the original space, and v_i is the variance of the column i. */ if (useFeatureScaling) { - weights = scaler.transform(weights) + if (numOfLinearPredictor == 1) { + weights = scaler.transform(weights) + } else { + /** + * For `numOfLinearPredictor > 1`, we have to transform the weights back to the original + * scale for each set of linear predictor. Note that the intercepts have to be explicitly + * excluded when `addIntercept == true` since the intercepts are part of weights now. + */ + var i = 0 + val n = weights.size / numOfLinearPredictor + val weightsArray = weights.toArray + while (i < numOfLinearPredictor) { + val start = i * n + val end = (i + 1) * n - { if (addIntercept) 1 else 0 } + + val partialWeightsArray = scaler.transform( + Vectors.dense(weightsArray.slice(start, end))).toArray + + System.arraycopy(partialWeightsArray, 0, weightsArray, start, partialWeightsArray.size) + i += 1 + } + weights = Vectors.dense(weightsArray) + } } // Warn at the end of the run as well, for increased visibility. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala new file mode 100644 index 0000000000000..cb70852e3cc8d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import java.io.Serializable +import java.lang.{Double => JDouble} +import java.util.Arrays.binarySearch + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * + * Regression model for isotonic regression. + * + * @param boundaries Array of boundaries for which predictions are known. + * Boundaries must be sorted in increasing order. + * @param predictions Array of predictions associated to the boundaries at the same index. + * Results of isotonic regression and therefore monotone. + * @param isotonic indicates whether this is isotonic or antitonic. 
+ */ +@Experimental +class IsotonicRegressionModel ( + val boundaries: Array[Double], + val predictions: Array[Double], + val isotonic: Boolean) extends Serializable { + + private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse + + require(boundaries.length == predictions.length) + assertOrdered(boundaries) + assertOrdered(predictions)(predictionOrd) + + /** Asserts the input array is monotone with the given ordering. */ + private def assertOrdered(xs: Array[Double])(implicit ord: Ordering[Double]): Unit = { + var i = 1 + while (i < xs.length) { + require(ord.compare(xs(i - 1), xs(i)) <= 0, + s"Elements (${xs(i - 1)}, ${xs(i)}) are not ordered.") + i += 1 + } + } + + /** + * Predict labels for provided features. + * Using a piecewise linear function. + * + * @param testData Features to be labeled. + * @return Predicted labels. + */ + def predict(testData: RDD[Double]): RDD[Double] = { + testData.map(predict) + } + + /** + * Predict labels for provided features. + * Using a piecewise linear function. + * + * @param testData Features to be labeled. + * @return Predicted labels. + */ + def predict(testData: JavaDoubleRDD): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]])) + } + + /** + * Predict a single label. + * Using a piecewise linear function. + * + * @param testData Feature to be labeled. + * @return Predicted label. + * 1) If testData exactly matches a boundary then associated prediction is returned. + * In case there are multiple predictions with the same boundary then one of them + * is returned. Which one is undefined (same as java.util.Arrays.binarySearch). + * 2) If testData is lower or higher than all boundaries then first or last prediction + * is returned respectively. In case there are multiple predictions with the same + * boundary then the lowest or highest is returned respectively. 
+ * 3) If testData falls between two values in boundary array then prediction is treated + * as piecewise linear function and interpolated value is returned. In case there are + * multiple values with the same boundary then the same rules as in 2) are used. + */ + def predict(testData: Double): Double = { + + def linearInterpolation(x1: Double, y1: Double, x2: Double, y2: Double, x: Double): Double = { + y1 + (y2 - y1) * (x - x1) / (x2 - x1) + } + + val foundIndex = binarySearch(boundaries, testData) + val insertIndex = -foundIndex - 1 + + // Find if the index was lower than all values, + // higher than all values, in between two values or exact match. + if (insertIndex == 0) { + predictions.head + } else if (insertIndex == boundaries.length){ + predictions.last + } else if (foundIndex < 0) { + linearInterpolation( + boundaries(insertIndex - 1), + predictions(insertIndex - 1), + boundaries(insertIndex), + predictions(insertIndex), + testData) + } else { + predictions(foundIndex) + } + } +} + +/** + * :: Experimental :: + * + * Isotonic regression. + * Currently implemented using parallelized pool adjacent violators algorithm. + * Only univariate (single feature) algorithm supported. + * + * Sequential PAV implementation based on: + * Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. + * "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. + * Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]] + * + * Sequential PAV parallelization based on: + * Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. + * "An approach to parallelizing isotonic regression." + * Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. 
+ * Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]] + * + * @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] + */ +@Experimental +class IsotonicRegression private (private var isotonic: Boolean) extends Serializable { + + /** + * Constructs IsotonicRegression instance with default parameter isotonic = true. + * + * @return New instance of IsotonicRegression. + */ + def this() = this(true) + + /** + * Sets the isotonic parameter. + * + * @param isotonic Isotonic (increasing) or antitonic (decreasing) sequence. + * @return This instance of IsotonicRegression. + */ + def setIsotonic(isotonic: Boolean): this.type = { + this.isotonic = isotonic + this + } + + /** + * Run IsotonicRegression algorithm to obtain isotonic regression model. + * + * @param input RDD of tuples (label, feature, weight) where label is dependent variable + * for which we calculate isotonic regression, feature is independent variable + * and weight represents number of measures with default 1. + * If multiple labels share the same feature value then they are ordered before + * the algorithm is executed. + * @return Isotonic regression model. + */ + def run(input: RDD[(Double, Double, Double)]): IsotonicRegressionModel = { + val preprocessedInput = if (isotonic) { + input + } else { + input.map(x => (-x._1, x._2, x._3)) + } + + val pooled = parallelPoolAdjacentViolators(preprocessedInput) + + val predictions = if (isotonic) pooled.map(_._1) else pooled.map(-_._1) + val boundaries = pooled.map(_._2) + + new IsotonicRegressionModel(boundaries, predictions, isotonic) + } + + /** + * Run pool adjacent violators algorithm to obtain isotonic regression model. + * + * @param input JavaRDD of tuples (label, feature, weight) where label is dependent variable + * for which we calculate isotonic regression, feature is independent variable + * and weight represents number of measures with default 1. 
+ * If multiple labels share the same feature value then they are ordered before + * the algorithm is executed. + * @return Isotonic regression model. + */ + def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = { + run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]]) + } + + /** + * Performs a pool adjacent violators algorithm (PAV). + * Uses approach with single processing of data where violators + * in previously processed data created by pooling are fixed immediately. + * Uses optimization of discovering monotonicity violating sequences (blocks). + * + * @param input Input data of tuples (label, feature, weight). + * @return Result tuples (label, feature, weight) where labels were updated + * to form a monotone sequence as per isotonic regression definition. + */ + private def poolAdjacentViolators( + input: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = { + + if (input.isEmpty) { + return Array.empty + } + + // Pools sub array within given bounds assigning weighted average value to all elements. + def pool(input: Array[(Double, Double, Double)], start: Int, end: Int): Unit = { + val poolSubArray = input.slice(start, end + 1) + + val weightedSum = poolSubArray.map(lp => lp._1 * lp._3).sum + val weight = poolSubArray.map(_._3).sum + + var i = start + while (i <= end) { + input(i) = (weightedSum / weight, input(i)._2, input(i)._3) + i = i + 1 + } + } + + var i = 0 + while (i < input.length) { + var j = i + + // Find monotonicity violating sequence, if any. + while (j < input.length - 1 && input(j)._1 > input(j + 1)._1) { + j = j + 1 + } + + // If monotonicity was not violated, move to next data point. + if (i == j) { + i = i + 1 + } else { + // Otherwise pool the violating sequence + // and check if pooling caused monotonicity violation in previously processed points. 
+ while (i >= 0 && input(i)._1 > input(i + 1)._1) { + pool(input, i, j) + i = i - 1 + } + + i = j + } + } + + // For points having the same prediction, we only keep two boundary points. + val compressed = ArrayBuffer.empty[(Double, Double, Double)] + + var (curLabel, curFeature, curWeight) = input.head + var rightBound = curFeature + def merge(): Unit = { + compressed += ((curLabel, curFeature, curWeight)) + if (rightBound > curFeature) { + compressed += ((curLabel, rightBound, 0.0)) + } + } + i = 1 + while (i < input.length) { + val (label, feature, weight) = input(i) + if (label == curLabel) { + curWeight += weight + rightBound = feature + } else { + merge() + curLabel = label + curFeature = feature + curWeight = weight + rightBound = curFeature + } + i += 1 + } + merge() + + compressed.toArray + } + + /** + * Performs parallel pool adjacent violators algorithm. + * Performs Pool adjacent violators algorithm on each partition and then again on the result. + * + * @param input Input data of tuples (label, feature, weight). + * @return Result tuples (label, feature, weight) where labels were updated + * to form a monotone sequence as per isotonic regression definition. + */ + private def parallelPoolAdjacentViolators( + input: RDD[(Double, Double, Double)]): Array[(Double, Double, Double)] = { + val parallelStepResult = input + .sortBy(x => (x._2, x._1)) + .glom() + .flatMap(poolAdjacentViolators) + .collect() + .sortBy(x => (x._2, x._1)) // Sort again because collect() doesn't promise ordering. 
+ poolAdjacentViolators(parallelStepResult) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 8ecd5c6ad93c0..e8b03816573cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,9 +17,11 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD /** @@ -32,7 +34,7 @@ class LassoModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +42,37 @@ class LassoModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object LassoModel extends Loader[LassoModel] { + + override def load(sc: SparkContext, path: String): LassoModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, 
classNameV1_0, numFeatures) + new LassoModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LassoModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L1-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 81b6598377ff5..6fa7ad52a5b33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -17,9 +17,12 @@ package org.apache.spark.mllib.regression -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.rdd.RDD /** * Regression model trained using LinearRegression. 
@@ -30,7 +33,8 @@ import org.apache.spark.mllib.optimization._ class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable + with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -38,12 +42,37 @@ class LinearRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object LinearRegressionModel extends Loader[LinearRegressionModel] { + + override def load(sc: SparkContext, path: String): LinearRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new LinearRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LinearRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a linear regression model with no regularization using Stochastic Gradient Descent. * This solves the least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + * f(weights) = 1/n ||A weights-y||^2^ * (which is the mean squared error). 
* Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 64b02f7a6e7a9..214ac4d0ed7dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -17,10 +17,12 @@ package org.apache.spark.mllib.regression +import org.json4s.{DefaultFormats, JValue} + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD @Experimental trait RegressionModel extends Serializable { @@ -48,3 +50,15 @@ trait RegressionModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object RegressionModel { + + /** + * Helper method for loading GLM regression model metadata. 
+ * @return numFeatures + */ + def getNumFeatures(metadata: JValue): Int = { + implicit val formats = DefaultFormats + (metadata \ "numFeatures").extract[Int] + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 076ba35051c9d..8838ca8c14718 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -17,10 +17,13 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.optimization._ +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.rdd.RDD + /** * Regression model trained using RidgeRegression. 
@@ -32,7 +35,7 @@ class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +43,37 @@ class RidgeRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object RidgeRegressionModel extends Loader[RidgeRegressionModel] { + + override def load(sc: SparkContext, path: String): RidgeRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new RidgeRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"RidgeRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. * This solves the l2-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y.
* See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index b549b7c475fc3..ce95c063db970 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -21,7 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream /** @@ -39,14 +41,14 @@ import org.apache.spark.streaming.dstream.DStream * * For example usage, see `StreamingLinearRegressionWithSGD`. * - * NOTE(Freeman): In some use cases, the order in which trainOn and predictOn + * NOTE: In some use cases, the order in which trainOn and predictOn * are called in an application will affect the results. When called on * the same DStream, if trainOn is called before predictOn, when new data * arrive the model will update and the prediction will be based on the new * model. Whereas if predictOn is called first, the prediction will use the model * from the previous update. * - * NOTE(Freeman): It is ok to call predictOn repeatedly on multiple streams; this + * NOTE: It is ok to call predictOn repeatedly on multiple streams; this * will generate predictions for each one all using the current model. * It is also ok to call trainOn on different streams; this will update * the model using each of the different sources, in sequence. 
@@ -58,14 +60,14 @@ abstract class StreamingLinearAlgorithm[
     A <: GeneralizedLinearAlgorithm[M]] extends Logging {
 
   /** The model to be updated and used for prediction. */
-  protected var model: M
+  protected var model: Option[M] = None
 
   /** The algorithm to use for updating. */
   protected val algorithm: A
 
   /** Return the latest model. */
   def latestModel(): M = {
-    model
+    model.get
   }
 
   /**
@@ -76,22 +78,32 @@ abstract class StreamingLinearAlgorithm[
    *
    * @param data DStream containing labeled data
    */
-  def trainOn(data: DStream[LabeledPoint]) {
-    if (Option(model.weights) == None) {
-      logError("Initial weights must be set before starting training")
-      throw new IllegalArgumentException
+  def trainOn(data: DStream[LabeledPoint]): Unit = {
+    if (model.isEmpty) {
+      throw new IllegalArgumentException("Model must be initialized before starting training.")
     }
     data.foreachRDD { (rdd, time) =>
-      model = algorithm.run(rdd, model.weights)
-      logInfo("Model updated at time %s".format(time.toString))
-      val display = model.weights.size match {
-        case x if x > 100 => model.weights.toArray.take(100).mkString("[", ",", "...")
-        case _ => model.weights.toArray.mkString("[", ",", "]")
+      val initialWeights =
+        model match {
+          case Some(m) =>
+            m.weights
+          case None =>
+            val numFeatures = rdd.first().features.size
+            Vectors.zeros(numFeatures) // zero start; dense(n) would build the 1-element [n]
       }
-      logInfo("Current model: weights, %s".format (display))
+      model = Some(algorithm.run(rdd, initialWeights))
+      logInfo("Model updated at time %s".format(time.toString))
+      val display = model.get.weights.size match {
+        case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
+        case _ => model.get.weights.toArray.mkString("[", ",", "]")
+      }
+      logInfo("Current model: weights, %s".format (display))
     }
   }
 
+  /** Java-friendly version of `trainOn`.
*/ + def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream) + /** * Use the model to make predictions on batches of data from a DStream * @@ -99,12 +111,15 @@ abstract class StreamingLinearAlgorithm[ * @return DStream containing predictions */ def predictOn(data: DStream[Vector]): DStream[Double] = { - if (Option(model.weights) == None) { - val msg = "Initial weights must be set before starting prediction" - logError(msg) - throw new IllegalArgumentException(msg) + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting prediction.") } - data.map(model.predict) + data.map(model.get.predict) + } + + /** Java-friendly version of `predictOn`. */ + def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = { + JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]]) } /** @@ -114,11 +129,17 @@ abstract class StreamingLinearAlgorithm[ * @return DStream containing the input keys and the predictions as values */ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { - if (Option(model.weights) == None) { - val msg = "Initial weights must be set before starting prediction" - logError(msg) - throw new IllegalArgumentException(msg) + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting prediction") } - data.mapValues(model.predict) + data.mapValues(model.get.predict) + } + + + /** Java-friendly version of `predictOnValues`. 
*/ + def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = { + implicit val tag = fakeClassTag[K] + JavaPairDStream.fromPairDStream( + predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Double)]]) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 1d11fde24712c..e5e6301127a28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector /** + * :: Experimental :: * Train or predict a linear regression model on streaming data. Training uses * Stochastic Gradient Descent to update the model based on each new batch of * incoming data from a DStream (see `LinearRegressionWithSGD` for model equation) @@ -41,13 +42,12 @@ import org.apache.spark.mllib.linalg.Vector * */ @Experimental -class StreamingLinearRegressionWithSGD ( +class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, - private var miniBatchFraction: Double, - private var initialWeights: Vector) - extends StreamingLinearAlgorithm[ - LinearRegressionModel, LinearRegressionWithSGD] with Serializable { + private var miniBatchFraction: Double) + extends StreamingLinearAlgorithm[LinearRegressionModel, LinearRegressionWithSGD] + with Serializable { /** * Construct a StreamingLinearRegression object with default parameters: @@ -55,12 +55,10 @@ class StreamingLinearRegressionWithSGD ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ - def this() = this(0.1, 50, 1.0, null) + def this() = this(0.1, 50, 1.0) val algorithm = new 
LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) - var model = algorithm.createModel(initialWeights, 0.0) - /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) @@ -81,7 +79,7 @@ class StreamingLinearRegressionWithSGD ( /** Set the initial weights. Default: [0.0, 0.0]. */ def setInitialWeights(initialWeights: Vector): this.type = { - this.model = algorithm.createModel(initialWeights, 0.0) + this.model = Some(algorithm.createModel(initialWeights, 0.0)) this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala new file mode 100644 index 0000000000000..bd7e340ca2d8e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.regression.impl + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +/** + * Helper methods for import/export of GLM regression models. + */ +private[regression] object GLMRegressionModel { + + object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + /** Model data for model import/export */ + case class Data(weights: Vector, intercept: Double) + + /** + * Helper method for saving GLM regression model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + weights: Vector, + intercept: Double): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> weights.size))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM regression model data. + * @param modelClass String name for model class (used for error messages) + * @param numFeatures Number of features, to be checked against loaded data. + * The length of the weights vector should equal numFeatures. 
+ */ + def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(datapath) + val dataArray = dataRDD.select("weights", "intercept").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 2, s"Unable to load $modelClass data from: $datapath") + data match { + case Row(weights: Vector, intercept: Double) => + assert(weights.size == numFeatures, s"Expected $numFeatures features, but" + + s" found ${weights.size} features when loading $modelClass weights from $datapath") + Data(weights, intercept) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala new file mode 100644 index 0000000000000..0deef11b4511a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.mllib.stat
+
+import org.apache.spark.rdd.RDD
+
+private[stat] object KernelDensity {
+  /**
+   * Estimates the sampled distribution's density at each evaluation point by averaging, over
+   * all samples, a Gaussian kernel centered on the sample with the given standard deviation.
+   */
+  def estimate(samples: RDD[Double], standardDeviation: Double,
+      evaluationPoints: Array[Double]): Array[Double] = {
+    if (standardDeviation <= 0.0) {
+      throw new IllegalArgumentException("Standard deviation must be positive")
+    }
+
+    // This gets used in each Gaussian PDF computation, so compute it up front
+    val logStandardDeviationPlusHalfLog2Pi =
+      Math.log(standardDeviation) + 0.5 * Math.log(2 * Math.PI)
+
+    val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0L))(
+      (x, y) => {
+        var i = 0
+        while (i < evaluationPoints.length) {
+          x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi,
+            evaluationPoints(i))
+          i += 1
+        }
+        (x._1, x._2 + 1) // one more sample folded in; `i` is only the evaluation-point index
+      },
+      (x, y) => {
+        var i = 0
+        while (i < evaluationPoints.length) {
+          x._1(i) += y._1(i)
+          i += 1
+        }
+        (x._1, x._2 + y._2)
+      })
+
+    var i = 0
+    while (i < points.length) {
+      points(i) /= count // turn the per-point kernel sums into an average over the samples
+      i += 1
+    }
+    points
+  }
+
+  private def normPdf(mean: Double, standardDeviation: Double,
+      logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = {
+    val x0 = x - mean
+    val x1 = x0 / standardDeviation
+    val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi
+    Math.exp(logDensity)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
index 3cf4e807b4cf7..32561620ac914 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -26,36 +26,32 @@ import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult}
 import org.apache.spark.rdd.RDD
 
 /**
+ * :: Experimental ::
  * API for
statistical functions in MLlib. */ @Experimental object Statistics { /** - * :: Experimental :: * Computes column-wise summary statistics for the input RDD[Vector]. * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. */ - @Experimental def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() } /** - * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. */ - @Experimental def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** - * :: Experimental :: * Compute the correlation matrix for the input RDD of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -69,11 +65,9 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. */ - @Experimental def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** - * :: Experimental :: * Compute the Pearson correlation for the input RDDs. * Returns NaN if either vector has 0 variance. * @@ -84,11 +78,9 @@ object Statistics { * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ - @Experimental def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** - * :: Experimental :: * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -99,14 +91,12 @@ object Statistics { * @param y RDD[Double] of the same cardinality as x. 
* @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` - *@return A Double containing the correlation between the two input RDD[Double]s using the + * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. */ - @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) /** - * :: Experimental :: * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. * @@ -120,13 +110,11 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) } /** - * :: Experimental :: * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform * distribution, with each category having an expected frequency of `1 / observed.size`. * @@ -136,11 +124,9 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) /** - * :: Experimental :: * Conduct Pearson's independence test on the input contingency matrix, which cannot contain * negative entries or columns or rows that sum up to 0. * @@ -148,11 +134,9 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) /** - * :: Experimental :: * Conduct Pearson's independence test for every feature against the label across the input RDD. 
* For each feature, the (feature, label) pairs are converted into a contingency matrix for which * the chi-squared statistic is computed. All label and feature values must be categorical. @@ -162,8 +146,21 @@ object Statistics { * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. */ - @Experimental def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } + + /** + * Given an empirical distribution defined by the input RDD of samples, estimate its density at + * each of the given evaluation points using a Gaussian kernel. + * + * @param samples The samples RDD used to define the empirical distribution. + * @param standardDeviation The standard deviation of the kernel Gaussians. + * @param evaluationPoints The points at which to estimate densities. + * @return An array the same size as evaluationPoints with the density at each point. 
+ */ + def kernelDensity(samples: RDD[Double], standardDeviation: Double, + evaluationPoints: Iterable[Double]): Array[Double] = { + KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index fd186b5ee6f72..cd6add9d60b0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat.distribution -import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym} +import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} @@ -62,21 +62,21 @@ class MultivariateGaussian ( /** Returns density of this multivariate Gaussian at given point, x */ def pdf(x: Vector): Double = { - pdf(x.toBreeze.toDenseVector) + pdf(x.toBreeze) } /** Returns the log-density of this multivariate Gaussian at given point, x */ def logpdf(x: Vector): Double = { - logpdf(x.toBreeze.toDenseVector) + logpdf(x.toBreeze) } /** Returns density of this multivariate Gaussian at given point, x */ - private[mllib] def pdf(x: DBV[Double]): Double = { + private[mllib] def pdf(x: BV[Double]): Double = { math.exp(logpdf(x)) } /** Returns the log-density of this multivariate Gaussian at given point, x */ - private[mllib] def logpdf(x: DBV[Double]): Double = { + private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu val v = rootSigmaInv * delta u + v.t * v * -0.5 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 
b3e8ed9af8c51..b9d0c56dd1ea3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - +import scala.collection.mutable.ArrayBuilder +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy @@ -32,13 +31,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.impl._ -import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.SparkContext._ - /** * :: Experimental :: @@ -331,14 +327,14 @@ object DecisionTree extends Serializable with Logging { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (feature, bin). * @param treePoint Data point being aggregated. - * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param splits possible splits indexed (numFeatures)(numSplits) * @param unorderedFeatures Set of indices of unordered features. * @param instanceWeight Weight (importance) of instance in dataset. 
*/ private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - bins: Array[Array[Bin]], + splits: Array[Array[Split]], unorderedFeatures: Set[Int], instanceWeight: Double, featuresForNode: Option[Array[Int]]): Unit = { @@ -366,7 +362,7 @@ object DecisionTree extends Serializable with Logging { val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { - if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { + if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) { agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } else { @@ -510,8 +506,8 @@ object DecisionTree extends Serializable with Logging { if (metadata.unorderedFeatures.isEmpty) { orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, - instanceWeight, featuresForNode) + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, + metadata.unorderedFeatures, instanceWeight, featuresForNode) } } } @@ -1028,35 +1024,15 @@ object DecisionTree extends Serializable with Logging { // Categorical feature val featureArity = metadata.featureArity(featureIndex) if (metadata.isUnordered(featureIndex)) { - // TODO: The second half of the bins are unused. Actually, we could just use - // splits and not build bins for unordered features. That should be part of - // a later PR since it will require changing other code (using splits instead - // of bins in a few places). 
// Unordered features - // 2^(maxFeatureValue - 1) - 1 combinations + // 2^(maxFeatureValue - 1) - 1 combinations splits(featureIndex) = new Array[Split](numSplits) - bins(featureIndex) = new Array[Bin](numBins) var splitIndex = 0 while (splitIndex < numSplits) { val categories: List[Double] = extractMultiClassCategories(splitIndex + 1, featureArity) splits(featureIndex)(splitIndex) = new Split(featureIndex, Double.MinValue, Categorical, categories) - bins(featureIndex)(splitIndex) = { - if (splitIndex == 0) { - new Bin( - new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), - Categorical, - Double.MinValue) - } else { - new Bin( - splits(featureIndex)(splitIndex - 1), - splits(featureIndex)(splitIndex), - Categorical, - Double.MinValue) - } - } splitIndex += 1 } } else { @@ -1064,8 +1040,11 @@ object DecisionTree extends Serializable with Logging { // Bins correspond to feature values, so we do not need to compute splits or bins // beforehand. Splits are constructed as needed during training. splits(featureIndex) = new Array[Split](0) - bins(featureIndex) = new Array[Bin](0) } + // For ordered features, bins correspond to feature values. + // For unordered categorical features, there is no need to construct the bins. + // since there is a one-to-one correspondence between the splits and the bins. + bins(featureIndex) = new Array[Bin](0) } featureIndex += 1 } @@ -1140,7 +1119,7 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) // iterate `valueCount` to find splits - val splits = new ArrayBuffer[Double] + val splitsBuilder = ArrayBuilder.make[Double] var index = 1 // currentCount: sum of counts of values that have been visited var currentCount = valueCounts(0)._2 @@ -1158,13 +1137,13 @@ object DecisionTree extends Serializable with Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. 
if (previousGap < currentGap) { - splits.append(valueCounts(index - 1)._1) + splitsBuilder += valueCounts(index - 1)._1 targetCount += stride } index += 1 } - splits.toArray + splitsBuilder.result() } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 61f6b1313f82e..a9c93e181e3ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost(input, boostingStrategy) + case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, boostingStrategy) + GradientBoostedTrees.boost(remappedInput, + remappedInput, boostingStrategy, validate=false) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -76,8 +77,46 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { run(input.rdd) } -} + /** + * Method to validate a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param validationInput Validation dataset: + RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + Should be different from and follow the same distribution as input. 
+ e.g., these two datasets could be created from an original dataset + by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @return a gradient boosted trees model that can be used for prediction + */ + def runWithValidation( + input: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case Regression => GradientBoostedTrees.boost( + input, validationInput, boostingStrategy, validate=true) + case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedInput = input.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val remappedValidationInput = validationInput.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, + validate=true) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. + */ + def runWithValidation( + input: JavaRDD[LabeledPoint], + validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { + runWithValidation(input.rdd, validationInput.rdd) + } +} object GradientBoostedTrees extends Logging { @@ -108,12 +147,16 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. * @param input training dataset + * @param validationInput validation dataset, ignored if validate is set to false. * @param boostingStrategy boosting parameters + * @param validate whether or not to use the validation dataset. 
* @return a gradient boosted trees model that can be used for prediction */ private def boost( input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + validationInput: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy, + validate: Boolean): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") @@ -129,6 +172,7 @@ object GradientBoostedTrees extends Logging { val learningRate = boostingStrategy.learningRate // Prepare strategy for individual trees, which use regression with variance impurity. val treeStrategy = boostingStrategy.treeStrategy.copy + val validationTol = boostingStrategy.validationTol treeStrategy.algo = Regression treeStrategy.impurity = Variance treeStrategy.assertValid() @@ -152,13 +196,16 @@ object GradientBoostedTrees extends Logging { baseLearnerWeights(0) = 1.0 val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) logDebug("error of gbt = " + loss.computeError(startingModel, input)) + // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") + var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0 + var bestM = 1 + // psuedo-residual for second iteration data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), point.features)) - var m = 1 while (m < numIterations) { timer.start(s"building tree $m") @@ -177,6 +224,23 @@ object GradientBoostedTrees extends Logging { val partialModel = new GradientBoostedTreesModel( Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) logDebug("error of gbt = " + loss.computeError(partialModel, input)) + + if (validate) { + // Stop training early if + // 1. Reduction in error is less than the validationTol or + // 2. If the error increases, that is if the model is overfit. + // We want the model returned corresponding to the best validation error. 
+ val currentValidateError = loss.computeError(partialModel, validationInput) + if (bestValidateError - currentValidateError < validationTol) { + return new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, + baseLearners.slice(0, bestM), + baseLearnerWeights.slice(0, bestM)) + } else if (currentValidateError < bestValidateError) { + bestValidateError = currentValidateError + bestM = m + 1 + } + } // Update data with pseudo-residuals data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), point.features)) @@ -187,8 +251,15 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - - new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + if (validate) { + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, + baseLearners.slice(0, bestM), + baseLearnerWeights.slice(0, bestM)) + } else { + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + } } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index e9304b5e5c650..db01f2e229e5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import java.io.IOException + import scala.collection.mutable import scala.collection.JavaConverters._ @@ -140,6 +142,7 @@ private class RandomForest ( logDebug("maxBins = " + metadata.maxBins) logDebug("featureSubsetStrategy = " + featureSubsetStrategy) logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + logDebug("subsamplingRate = " + strategy.subsamplingRate) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. 
@@ -155,19 +158,12 @@ private class RandomForest ( // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - val (subsample, withReplacement) = { - // TODO: Have a stricter check for RF in the strategy - val isRandomForest = numTrees > 1 - if (isRandomForest) { - (1.0, true) - } else { - (strategy.subsamplingRate, false) - } - } + val withReplacement = if (numTrees > 1) true else false val baggedInput - = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed) - .persist(StorageLevel.MEMORY_AND_DISK) + = BaggedPoint.convertToBaggedRDD(treeInput, + strategy.subsamplingRate, numTrees, + withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth @@ -208,7 +204,6 @@ private class RandomForest ( Some(NodeIdCache.init( data = baggedInput, numTrees = numTrees, - checkpointDir = strategy.checkpointDir, checkpointInterval = strategy.checkpointInterval, initVal = 1)) } else { @@ -250,7 +245,12 @@ private class RandomForest ( // Delete any remaining checkpoints used for node Id cache. if (nodeIdCache.nonEmpty) { - nodeIdCache.get.deleteAllCheckpoints() + try { + nodeIdCache.get.deleteAllCheckpoints() + } catch { + case e:IOException => + logWarning(s"delete all chackpoints failed. 
Error reason: ${e.getMessage}") + } } val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 0ef9c6181a0a0..b6099259971b7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -29,8 +29,8 @@ object Algo extends Enumeration { val Classification, Regression = Value private[mllib] def fromString(name: String): Algo = name match { - case "classification" => Classification - case "regression" => Regression + case "classification" | "Classification" => Classification + case "regression" | "Regression" => Regression case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c65a9..664c8df019233 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -34,6 +34,9 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * weak hypotheses used in the final model. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] + * @param validationTol Useful when runWithValidation is used. If the error rate on the + * validation input between two iterations is less than the validationTol + * then stop. Ignored when [[run]] is used. 
*/ @Experimental case class BoostingStrategy( @@ -42,7 +45,8 @@ case class BoostingStrategy( @BeanProperty var loss: Loss, // Optional boosting parameters @BeanProperty var numIterations: Int = 100, - @BeanProperty var learningRate: Double = 0.1) extends Serializable { + @BeanProperty var learningRate: Double = 0.1, + @BeanProperty var validationTol: Double = 1e-5) extends Serializable { /** * Check validity of parameters. @@ -68,6 +72,15 @@ case class BoostingStrategy( @Experimental object BoostingStrategy { + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: "Classification" or "Regression" + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + defaultParams(Algo.fromString(algo)) + } + /** * Returns default configuration for the boosting algorithm * @param algo Learning goal. Supported: @@ -75,15 +88,15 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) - treeStrategy.maxDepth = 3 + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 algo match { - case "Classification" => - treeStrategy.numClasses = 2 - new BoostingStrategy(treeStrategy, LogLoss) - case "Regression" => - new BoostingStrategy(treeStrategy, SquaredError) + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 
d5cd89ab94e81..8d5c36da32bdb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will * maintain a separate RDD of node Id cache for each row. - * @param checkpointDir If the node Id cache is used, it will help to checkpoint - * the node Id cache periodically. This is the checkpoint directory - * to be used for the node Id cache. * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. - * E.g. 10 means that the cache will get checkpointed every 10 updates. + * E.g. 10 means that the cache will get checkpointed every 10 updates. If + * the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. */ @Experimental class Strategy ( @@ -82,7 +81,6 @@ class Strategy ( @BeanProperty var maxMemoryInMB: Int = 256, @BeanProperty var subsamplingRate: Double = 1, @BeanProperty var useNodeIdCache: Boolean = false, - @BeanProperty var checkpointDir: Option[String] = None, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { def isMulticlassClassification = @@ -156,13 +154,16 @@ class Strategy ( s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") require(maxMemoryInMB <= 10240, s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") + require(subsamplingRate > 0 && subsamplingRate <= 1, + s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " + + s"$subsamplingRate") } /** Returns a shallow copy of this instance. 
*/ def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, - maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval) + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval) } } @@ -173,11 +174,19 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { - case "Classification" => + def defaultStrategy(algo: String): Strategy = { + defaultStategy(Algo.fromString(algo)) + } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) - case "Regression" => + case Algo.Regression => new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 951733fada6be..f1a6ed230186e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -183,7 +183,7 @@ private[tree] object DecisionTreeMetadata extends Logging { } /** - * Version of [[buildMetadata()]] for DecisionTree. + * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. 
*/ def buildMetadata( input: RDD[LabeledPoint], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 83011b48b7d9b..bdd0f576b048d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater( * The nodeIdsForInstances RDD needs to be updated at each iteration. * @param nodeIdsForInstances The initial values in the cache * (should be an Array of all 1's (meaning the root nodes)). - * @param checkpointDir The checkpoint directory where - * the checkpointed files will be stored. * @param checkpointInterval The checkpointing interval * (how often should the cache be checkpointed.). */ @DeveloperApi private[tree] class NodeIdCache( var nodeIdsForInstances: RDD[Array[Int]], - val checkpointDir: Option[String], val checkpointInterval: Int) { // Keep a reference to a previous node Ids for instances. @@ -91,12 +88,6 @@ private[tree] class NodeIdCache( private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() private var rddUpdateCount = 0 - // If a checkpoint directory is given, and there's no prior checkpoint directory, - // then set the checkpoint directory with the given one. - if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) { - nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get) - } - /** * Update the node index values in the cache. * This updates the RDD and its lineage. @@ -184,7 +175,6 @@ private[tree] object NodeIdCache { * Initialize the node Id cache with initial node Id values. * @param data The RDD of training rows. * @param numTrees The number of trees that we want to create cache for. - * @param checkpointDir The checkpoint directory where the checkpointed files will be stored. 
* @param checkpointInterval The checkpointing interval * (how often should the cache be checkpointed.). * @param initVal The initial values in the cache. @@ -193,12 +183,10 @@ private[tree] object NodeIdCache { def init( data: RDD[BaggedPoint[TreePoint]], numTrees: Int, - checkpointDir: Option[String], checkpointInterval: Int, initVal: Int = 1): NodeIdCache = { new NodeIdCache( data.map(_ => Array.fill[Int](numTrees)(initVal)), - checkpointDir, checkpointInterval) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index 35e361ae309cc..50b292e71b067 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -55,17 +55,15 @@ private[tree] object TreePoint { input: RDD[LabeledPoint], bins: Array[Array[Bin]], metadata: DecisionTreeMetadata): RDD[TreePoint] = { - // Construct arrays for featureArity and isUnordered for efficiency in the inner loop. + // Construct arrays for featureArity for efficiency in the inner loop. val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) - val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures) var featureIndex = 0 while (featureIndex < metadata.numFeatures) { featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) - isUnordered(featureIndex) = metadata.isUnordered(featureIndex) featureIndex += 1 } input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered) + TreePoint.labeledPointToTreePoint(x, bins, featureArity) } } @@ -74,19 +72,17 @@ private[tree] object TreePoint { * @param bins Bins for features, of size (numFeatures, numBins). * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories * for categorical features. 
- * @param isUnordered Array index by feature, with value true for unordered categorical features. */ private def labeledPointToTreePoint( labeledPoint: LabeledPoint, bins: Array[Array[Bin]], - featureArity: Array[Int], - isUnordered: Array[Boolean]): TreePoint = { + featureArity: Array[Int]): TreePoint = { val numFeatures = labeledPoint.features.size val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), - isUnordered(featureIndex), bins) + bins) featureIndex += 1 } new TreePoint(labeledPoint.label, arr) @@ -96,14 +92,12 @@ private[tree] object TreePoint { * Find bin for one (labeledPoint, feature). * * @param featureArity 0 for continuous features; number of categories for categorical features. - * @param isUnorderedFeature (only applies if feature is categorical) * @param bins Bins for features, of size (numFeatures, numBins). */ private def findBin( featureIndex: Int, labeledPoint: LabeledPoint, featureArity: Int, - isUnorderedFeature: Boolean, bins: Array[Array[Bin]]): Int = { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 0e02345aa3774..b7950e00786ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int) throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s" but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc val lbl = 
label.toInt require(lbl < stats.length, s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "Entropy does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 7c83cd48e16a0..c946db9c0d1c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int) throw new IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s" but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula val lbl = label.toInt require(lbl < stats.length, s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "GiniImpurity does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 4bca9039ebe1d..e1169d9f66ea4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -45,7 +45,7 @@ trait Loss extends Serializable { * purposes. * @param model Model of the weak learner. * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return + * @return Measure of model error on data */ def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a5760963068c3..060fd5b859a51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,11 +17,21 @@ package org.apache.spark.mllib.tree.model +import scala.collection.mutable + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** * :: Experimental :: @@ -31,7 +41,7 @@ import org.apache.spark.rdd.RDD * @param algo algorithm type -- classification or regression */ @Experimental -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { +class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable { /** * Predict values for a single data point using the model trained. @@ -53,7 +63,6 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } - /** * Predict values for the given data set using the model trained. 
* @@ -99,4 +108,183 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable header + topNode.subtreeToString(2) } + override def save(sc: SparkContext, path: String): Unit = { + DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) + } + + override protected def formatVersion: String = "1.0" +} + +object DecisionTreeModel extends Loader[DecisionTreeModel] { + + private[tree] object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + // Hard-code class name string in case it changes in the future + def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel" + + case class PredictData(predict: Double, prob: Double) { + def toPredict: Predict = new Predict(predict, prob) + } + + object PredictData { + def apply(p: Predict): PredictData = PredictData(p.predict, p.prob) + + def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1)) + } + + case class SplitData( + feature: Int, + threshold: Double, + featureType: Int, + categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed + def toSplit: Split = { + new Split(feature, threshold, FeatureType(featureType), categories.toList) + } + } + + object SplitData { + def apply(s: Split): SplitData = { + SplitData(s.feature, s.threshold, s.featureType.id, s.categories) + } + + def apply(r: Row): SplitData = { + SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3)) + } + } + + /** Model data for model import/export */ + case class NodeData( + treeId: Int, + nodeId: Int, + predict: PredictData, + impurity: Double, + isLeaf: Boolean, + split: Option[SplitData], + leftNodeId: Option[Int], + rightNodeId: Option[Int], + infoGain: Option[Double]) + + object NodeData { + def apply(treeId: Int, n: Node): NodeData = { + NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf, + n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id), + n.stats.map(_.gain)) + } + + def apply(r: Row): NodeData = { + val split = if 
(r.isNullAt(5)) None else Some(SplitData(r.getStruct(5))) + val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6)) + val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7)) + val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8)) + NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3), + r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain) + } + } + + def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val nodes = model.topNode.subtreeIterator.toSeq + val dataRDD: DataFrame = sc.parallelize(nodes) + .map(NodeData.apply(0, _)) + .toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(datapath) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[NodeData](dataRDD.schema) + val nodes = dataRDD.map(NodeData.apply) + // Build node data into a tree. + val trees = constructTrees(nodes) + assert(trees.size == 1, + s"Decision tree should contain exactly one tree but got ${trees.size} trees.") + val model = new DecisionTreeModel(trees(0), Algo.fromString(algo)) + assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath."
+ + s" Expected $numNodes nodes but found ${model.numNodes}") + model + } + + def constructTrees(nodes: RDD[NodeData]): Array[Node] = { + val trees = nodes + .groupBy(_.treeId) + .mapValues(_.toArray) + .collect() + .map { case (treeId, data) => + (treeId, constructTree(data)) + }.sortBy(_._1) + val numTrees = trees.size + val treeIndices = trees.map(_._1).toSeq + assert(treeIndices == (0 until numTrees), + s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.") + trees.map(_._2) + } + + /** + * Given a list of nodes from a tree, construct the tree. + * @param data array of all node data in a tree. + */ + def constructTree(data: Array[NodeData]): Node = { + val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap + assert(dataMap.contains(1), + s"DecisionTree missing root node (id = 1).") + constructNode(1, dataMap, mutable.Map.empty) + } + + /** + * Builds a node from the node data map and adds new nodes to the input nodes map. + */ + private def constructNode( + id: Int, + dataMap: Map[Int, NodeData], + nodes: mutable.Map[Int, Node]): Node = { + if (nodes.contains(id)) { + return nodes(id) + } + val data = dataMap(id) + val node = + if (data.isLeaf) { + Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf) + } else { + val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes) + val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes) + val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity, + rightNode.impurity, leftNode.predict, rightNode.predict) + new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf, + data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats)) + } + nodes += node.id -> node + node + } + } + + override def load(sc: SparkContext, path: String): DecisionTreeModel = { + implicit val formats = DefaultFormats + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val algo = (metadata \ 
"algo").extract[String] + val numNodes = (metadata \ "numNodes").extract[Int] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path, algo, numNodes) + case _ => throw new Exception( + s"DecisionTreeModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 9a50ecb550c38..80990aa9a603f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -49,7 +49,9 @@ class InformationGainStats( gain == other.gain && impurity == other.impurity && leftImpurity == other.leftImpurity && - rightImpurity == other.rightImpurity + rightImpurity == other.rightImpurity && + leftPredict == other.leftPredict && + rightPredict == other.rightPredict } case _ => false } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 2179da8dbe03e..d961081d185e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -166,6 +166,11 @@ class Node ( } } + /** Returns an iterator that traverses (DFS, left to right) the subtree of this node. 
*/ + private[tree] def subtreeIterator: Iterator[Node] = { + Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++ + rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty) + } } private[tree] object Node { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 004838ee5ba0e..ad4c0dbbfb3e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -32,4 +32,11 @@ class Predict( override def toString = { "predict = %f, prob = %f".format(predict, prob) } + + override def equals(other: Any): Boolean = { + other match { + case p: Predict => predict == p.predict && prob == p.prob + case _ => false + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 22997110de8dd..4897906aea5b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -20,13 +20,20 @@ package org.apache.spark.mllib.tree.model import scala.collection.mutable import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext /** * :: Experimental :: @@ 
-38,9 +45,42 @@ import org.apache.spark.rdd.RDD @Experimental class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), - combiningStrategy = if (algo == Classification) Vote else Average) { + combiningStrategy = if (algo == Classification) Vote else Average) + with Saveable { require(trees.forall(_.algo == algo)) + + override def save(sc: SparkContext, path: String): Unit = { + TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, + RandomForestModel.SaveLoadV1_0.thisClassName) + } + + override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion +} + +object RandomForestModel extends Loader[RandomForestModel] { + + override def load(sc: SparkContext, path: String): RandomForestModel = { + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) + assert(metadata.treeWeights.forall(_ == 1.0)) + val trees = + TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) + new RandomForestModel(Algo.fromString(metadata.algo), trees) + case _ => throw new Exception(s"RandomForestModel.load did not recognize model" + + s" with (className, format version): ($loadedClassName, $version). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private object SaveLoadV1_0 { + // Hard-code class name string in case it changes in the future + def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel" + } + } /** @@ -56,9 +96,42 @@ class GradientBoostedTreesModel( override val algo: Algo, override val trees: Array[DecisionTreeModel], override val treeWeights: Array[Double]) - extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) { + extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) + with Saveable { require(trees.size == treeWeights.size) + + override def save(sc: SparkContext, path: String): Unit = { + TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, + GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) + } + + override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion +} + +object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { + + override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) + assert(metadata.combiningStrategy == Sum.toString) + val trees = + TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) + new GradientBoostedTreesModel(Algo.fromString(metadata.algo), trees, metadata.treeWeights) + case _ => throw new Exception(s"GradientBoostedTreesModel.load did not recognize model" + + s" with (className, format version): ($loadedClassName, $version). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private object SaveLoadV1_0 { + // Hard-code class name string in case it changes in the future + def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" + } + } /** @@ -176,3 +249,74 @@ private[tree] sealed class TreeEnsembleModel( */ def totalNumNodes: Int = trees.map(_.numNodes).sum } + +private[tree] object TreeEnsembleModel { + + object SaveLoadV1_0 { + + import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees} + + def thisFormatVersion = "1.0" + + case class Metadata( + algo: String, + treeAlgo: String, + combiningStrategy: String, + treeWeights: Array[Double]) + + /** + * Model data for model import/export. + * We have to duplicate NodeData here since Spark SQL does not yet support extracting subfields + * of nested fields; once that is possible, we can use something like: + * case class EnsembleNodeData(treeId: Int, node: NodeData), + * where NodeData is from DecisionTreeModel. + */ + case class EnsembleNodeData(treeId: Int, node: NodeData) + + def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + implicit val format = DefaultFormats + val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString, + model.combiningStrategy.toString, model.treeWeights) + val metadata = compact(render( + ("class" -> className) ~ ("version" -> thisFormatVersion) ~ + ("metadata" -> Extraction.decompose(ensembleMetadata)))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. 
+ val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => + tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) + }.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Read metadata from the loaded JSON metadata. + */ + def readMetadata(metadata: JValue): Metadata = { + implicit val formats = DefaultFormats + (metadata \ "metadata").extract[Metadata] + } + + /** + * Load trees for an ensemble, and return them in order. + * @param path path to load the model from + * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's + * algorithm). + */ + def loadTrees( + sc: SparkContext, + path: String, + treeAlgo: String): Array[DecisionTreeModel] = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply) + val trees = constructTrees(nodes) + trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala index 45f95482a1def..be335a1aca58a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -34,11 +34,27 @@ object DataValidators extends Logging { * * @return True if labels are all zero or one, false otherwise. */ - val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data => + val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() if (numInvalid != 0) { logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels") } numInvalid == 0 } + + /** + * Function to check if labels used for k class multi-label classification are + * in the range of {0, 1, ..., k - 1}. 
+ * + * @return True if labels are all in the range of {0, 1, ..., k-1}, false otherwise. + */ + def multiLabelValidator(k: Int): RDD[LabeledPoint] => Boolean = { data => + val numInvalid = data.filter(x => + x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count() + if (numInvalid != 0) { + logError("Classification labels should be in {0 to " + (k - 1) + "}. " + + "Found " + numInvalid + " invalid labels") + } + numInvalid == 0 + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 69299c219878c..97f54aa62d31c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -62,7 +62,7 @@ object LinearDataGenerator { * @param nPoints Number of points in sample. * @param seed Random seed * @param eps Epsilon scaling factor. - * @return + * @return Seq of input. 
*/ def generateLinearInput( intercept: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index f7cba6c6cb628..308f7f3578e21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import java.util.StringTokenizer -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.{ArrayBuilder, ListBuffer} import org.apache.spark.SparkException @@ -51,7 +51,7 @@ private[mllib] object NumericParser { } private def parseArray(tokenizer: StringTokenizer): Array[Double] = { - val values = ArrayBuffer.empty[Double] + val values = ArrayBuilder.make[Double] var parsing = true var allowComma = false var token: String = null @@ -67,14 +67,14 @@ private[mllib] object NumericParser { } } else { // expecting a number - values.append(parseDouble(token)) + values += parseDouble(token) allowComma = true } } if (parsing) { throw new SparkException(s"An array must end with ']'.") } - values.toArray + values.result() } private def parseTuple(tokenizer: StringTokenizer): Seq[_] = { @@ -114,7 +114,7 @@ private[mllib] object NumericParser { try { java.lang.Double.parseDouble(s) } catch { - case e: Throwable => + case e: NumberFormatException => throw new SparkException(s"Cannot parse a double from: $s", e) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala new file mode 100644 index 0000000000000..4458340497f0b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * + * Trait for models and transformers which may be saved as files. + * This should be inherited by the class which implements model instances. + */ +@DeveloperApi +trait Saveable { + + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Loader.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * This directory and any intermediate directory will be created if needed. + */ + def save(sc: SparkContext, path: String): Unit + + /** Current version of model save/load format. 
*/ + protected def formatVersion: String + +} + +/** + * :: DeveloperApi :: + * + * Trait for classes which can load models and transformers from files. + * This should be inherited by an object paired with the model class. + */ +@DeveloperApi +trait Loader[M <: Saveable] { + + /** + * Load a model from the given path. + * + * The model should have been saved by [[Saveable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + def load(sc: SparkContext, path: String): M + +} + +/** + * Helper methods for loading models from files. + */ +private[mllib] object Loader { + + /** Returns URI for path/data using the Hadoop filesystem */ + def dataPath(path: String): String = new Path(path, "data").toUri.toString + + /** Returns URI for path/metadata using the Hadoop filesystem */ + def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString + + /** + * Check the schema of loaded model data. + * + * This checks every field in the expected schema to make sure that a field with the same + * name and DataType appears in the loaded schema. Note that this does NOT check metadata + * or containsNull. + * + * @param loadedSchema Schema for model data loaded from file. + * @tparam Data Expected data type from which an expected schema can be derived. + */ + def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = { + // Check schema explicitly since erasure makes it hard to use match-case for checking. + val expectedFields: Array[StructField] = + ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields + val loadedFields: Map[String, DataType] = + loadedSchema.map(field => field.name -> field.dataType).toMap + expectedFields.foreach { field => + assert(loadedFields.contains(field.name), s"Unable to parse model data." 
+ + s" Expected field with name ${field.name} was missing in loaded schema:" + + s" ${loadedFields.mkString(", ")}") + assert(loadedFields(field.name) == field.dataType, + s"Unable to parse model data. Expected field $field but found field" + + s" with different type: ${loadedFields(field.name)}") + } + } + + /** + * Load metadata from the given path. + * @return (class name, version, metadata) + */ + def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = { + implicit val formats = DefaultFormats + val metadata = parse(sc.textFile(metadataPath(path)).first()) + val clazz = (metadata \ "class").extract[String] + val version = (metadata \ "version").extract[String] + (clazz, version, metadata) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 42846677ed285..0a8c9e5954676 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -26,10 +26,9 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; /** * Test Pipeline construction and fitting in Java. 
@@ -37,16 +36,16 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; - private transient JavaSQLContext jsql; - private transient JavaSchemaRDD dataset; + private transient SQLContext jsql; + private transient DataFrame dataset; @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaPipelineSuite"); - jsql = new JavaSQLContext(jsc); + jsql = new SQLContext(jsc); JavaRDD points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); - dataset = jsql.applySchema(points, LabeledPoint.class); + dataset = jsql.createDataFrame(points, LabeledPoint.class); } @After @@ -66,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); - predictions.collect(); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 76eb7f00329f2..3f8e59de0f05c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -18,31 +18,40 @@ package org.apache.spark.ml.classification; import java.io.Serializable; +import java.lang.Math; import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import 
org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Row; + public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; - private transient JavaSQLContext jsql; - private transient JavaSchemaRDD dataset; + private transient SQLContext jsql; + private transient DataFrame dataset; + + private transient JavaRDD datasetRDD; + private double eps = 1e-5; @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - jsql = new JavaSQLContext(jsc); + jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.registerTempTable("dataset"); } @After @@ -52,29 +61,88 @@ public void tearDown() { } @Test - public void logisticRegression() { + public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); + assert(lr.getLabelCol().equals("label")); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); - predictions.collect(); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + predictions.collectAsList(); + // Check defaults + assert(model.getThreshold() == 0.5); + assert(model.getFeaturesCol().equals("features")); + assert(model.getPredictionCol().equals("prediction")); + assert(model.getProbabilityCol().equals("probability")); } @Test public void logisticRegressionWithSetters() { + // Set 
params, train, and check as many params as we can. LogisticRegression lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(1.0); + .setRegParam(1.0) + .setThreshold(0.6) + .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold - .registerTempTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); - predictions.collect(); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); + assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); + assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6)); + assert(model.getThreshold() == 0.6); + + // Modify model params, and check that the params worked. + model.setThreshold(1.0); + model.transform(dataset).registerTempTable("predAllZero"); + DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); + for (Row r: predAllZero.collectAsList()) { + assert(r.getDouble(0) == 0.0); + } + // Call transform with params, and check that the params worked. + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) + .registerTempTable("predNotAllZero"); + DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + boolean foundNonZero = false; + for (Row r: predNotAllZero.collectAsList()) { + if (r.getDouble(0) != 0.0) foundNonZero = true; + } + assert(foundNonZero); + + // Call fit() with new params, and check as many params as we can. 
+ LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); + assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); + assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4)); + assert(model2.getThreshold() == 0.4); + assert(model2.getProbabilityCol().equals("theProb")); } + @SuppressWarnings("unchecked") @Test - public void logisticRegressionFitWithVarargs() { + public void logisticRegressionPredictorClassifierMethods() { LogisticRegression lr = new LogisticRegression(); - lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0)); + LogisticRegressionModel model = lr.fit(dataset); + assert(model.numClasses() == 2); + + model.transform(dataset).registerTempTable("transformed"); + DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row: trans1.collect()) { + Vector raw = (Vector)row.get(0); + Vector prob = (Vector)row.get(1); + assert(raw.size() == 2); + assert(prob.size() == 2); + double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); + assert(Math.abs(prob.apply(1) - probFromRaw1) < eps); + assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps); + } + + DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); + for (Row row: trans2.collect()) { + double pred = row.getDouble(0); + Vector prob = (Vector)row.get(1); + double probOfPred = prob.apply((int)pred); + for (int i = 0; i < prob.size(); ++i) { + assert(probOfPred >= prob.apply(i)); + } + } } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java new file mode 100644 index 0000000000000..640d2ec55e4e7 --- /dev/null +++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; + +public class JavaStreamingLogisticRegressionSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + 
.set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(1.0)), + new LabeledPoint(0.0, Vectors.dense(0.0))); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() + .setNumIterations(2) + .setInitialWeights(Vectors.dense(0.0)); + slr.trainOn(training); + JavaPairDStream prediction = slr.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java new file mode 100644 index 0000000000000..0cc36c8d56d70 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaLinearRegressionSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + private transient DataFrame dataset; + private transient JavaRDD datasetRDD; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + jsql = new SQLContext(jsc); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.registerTempTable("dataset"); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void linearRegressionDefaultParams() { + LinearRegression lr = new LinearRegression(); + assert(lr.getLabelCol().equals("label")); + LinearRegressionModel model = lr.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); + predictions.collect(); + // Check defaults + 
assert(model.getFeaturesCol().equals("features")); + assert(model.getPredictionCol().equals("prediction")); + } + + @Test + public void linearRegressionWithSetters() { + // Set params, train, and check as many params as we can. + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(1.0); + LinearRegressionModel model = lr.fit(dataset); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); + assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); + + // Call fit() with new params, and check as many params as we can. + LinearRegressionModel model2 = + lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); + assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); + assert(model2.getPredictionCol().equals("thePred")); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index a266ebd2071a1..0bb6b489f2757 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,23 +30,22 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; - private transient JavaSQLContext jsql; - 
private transient JavaSchemaRDD dataset; + private transient SQLContext jsql; + private transient DataFrame dataset; @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); - jsql = new JavaSQLContext(jsc); + jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); } @After diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java new file mode 100644 index 0000000000000..dc10aa67c7c1f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.ArrayList; + +import org.apache.spark.api.java.JavaRDD; +import scala.Tuple2; + +import org.junit.After; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; + + +public class JavaLDASuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLDA"); + ArrayList> tinyCorpus = new ArrayList>(); + for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) { + tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(), + LDASuite$.MODULE$.tinyCorpus()[i]._2())); + } + JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); + corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void localLDAModel() { + LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics()); + + // Check: basic parameters + assertEquals(model.k(), tinyK); + assertEquals(model.vocabSize(), tinyVocabSize); + assertEquals(model.topicsMatrix(), tinyTopics); + + // Check: describeTopics() with all terms + Tuple2[] fullTopicSummary = model.describeTopics(); + assertEquals(fullTopicSummary.length, tinyK); + for (int i = 0; i < fullTopicSummary.length; i++) { + assertArrayEquals(fullTopicSummary[i]._1(), tinyTopicDescription[i]._1()); + assertArrayEquals(fullTopicSummary[i]._2(), tinyTopicDescription[i]._2(), 1e-5); + } + } + + @Test + public void distributedLDAModel() { + int k = 3; + double topicSmoothing = 1.2; + double termSmoothing = 1.2; + + // Train a model + LDA lda = new LDA(); + lda.setK(k) + 
.setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345); + + DistributedLDAModel model = lda.run(corpus); + + // Check: basic parameters + LocalLDAModel localModel = model.toLocal(); + assertEquals(model.k(), k); + assertEquals(localModel.k(), k); + assertEquals(model.vocabSize(), tinyVocabSize); + assertEquals(localModel.vocabSize(), tinyVocabSize); + assertEquals(model.topicsMatrix(), localModel.topicsMatrix()); + + // Check: topic summaries + Tuple2[] roundedTopicSummary = model.describeTopics(); + assertEquals(roundedTopicSummary.length, k); + Tuple2[] roundedLocalTopicSummary = localModel.describeTopics(); + assertEquals(roundedLocalTopicSummary.length, k); + + // Check: log probabilities + assert(model.logLikelihood() < 0.0); + assert(model.logPrior() < 0.0); + } + + private static int tinyK = LDASuite$.MODULE$.tinyK(); + private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); + private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); + private static Tuple2[] tinyTopicDescription = + LDASuite$.MODULE$.tinyTopicDescription(); + JavaPairRDD corpus; + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java new file mode 100644 index 0000000000000..bd0edf2b9ea62 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; +import static org.junit.Assert.*; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; + +public class JavaFPGrowthSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaFPGrowth"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runFPGrowth() { + + @SuppressWarnings("unchecked") + JavaRDD> rdd = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("r z h k p".split(" ")), + Lists.newArrayList("z y x w v u t s".split(" ")), + Lists.newArrayList("s x o n r".split(" ")), + Lists.newArrayList("x z y m t s q e".split(" ")), + Lists.newArrayList("z".split(" ")), + Lists.newArrayList("x z y r q t p".split(" "))), 2); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd); + + List> freqItemsets = model.freqItemsets().toJavaRDD().collect(); + assertEquals(18, freqItemsets.size()); + + for (FreqItemset itemset: freqItemsets) { + // Test return types. 
+ List items = itemset.javaItems(); + long freq = itemset.freq(); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 704d484d0b585..3349c5022423a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -71,8 +71,8 @@ public void diagonalMatrixConstruction() { Matrix sm = Matrices.diag(sv); DenseMatrix d = DenseMatrix.diag(v); DenseMatrix sd = DenseMatrix.diag(sv); - SparseMatrix s = SparseMatrix.diag(v); - SparseMatrix ss = SparseMatrix.diag(sv); + SparseMatrix s = SparseMatrix.spdiag(v); + SparseMatrix ss = SparseMatrix.spdiag(sv); assertArrayEquals(m.toArray(), sm.toArray(), 0.0); assertArrayEquals(d.toArray(), sm.toArray(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java new file mode 100644 index 0000000000000..d38fc91ace3cf --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple3; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaIsotonicRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + private List> generateIsotonicInput(double[] labels) { + List> input = Lists.newArrayList(); + + for (int i = 1; i <= labels.length; i++) { + input.add(new Tuple3(labels[i-1], (double) i, 1d)); + } + + return input; + } + + private IsotonicRegressionModel runIsotonicRegression(double[] labels) { + JavaRDD> trainRDD = + sc.parallelize(generateIsotonicInput(labels), 2).cache(); + + return new IsotonicRegression().run(trainRDD); + } + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testIsotonicRegressionJavaRDD() { + IsotonicRegressionModel model = + runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); + + Assert.assertArrayEquals( + new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14); + } + + @Test + public void testIsotonicRegressionPredictionsJavaRDD() { + IsotonicRegressionModel model = + runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); + + JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0)); + List predictions = model.predict(testRDD).collect(); + + Assert.assertTrue(predictions.get(0) == 1d); + Assert.assertTrue(predictions.get(1) == 1d); + 
Assert.assertTrue(predictions.get(2) == 10d); + Assert.assertTrue(predictions.get(3) == 12d); + Assert.assertTrue(predictions.get(4) == 12d); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java new file mode 100644 index 0000000000000..899c4ea607869 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; + +public class JavaStreamingLinearRegressionSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(1.0)), + new LabeledPoint(0.0, Vectors.dense(0.0))); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() + .setNumIterations(2) + .setInitialWeights(Vectors.dense(0.0)); + slr.trainOn(training); + JavaPairDStream prediction = slr.predictOnValues(test); + attachTestOutputStream(prediction.count()); + 
runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4515084bc7ae9..2f175fb117941 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame class PipelineSuite extends FunSuite { @@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite { val estimator2 = mock[Estimator[MyModel]] val model2 = mock[MyModel] val transformer3 = mock[Transformer] - val dataset0 = mock[SchemaRDD] - val dataset1 = mock[SchemaRDD] - val dataset2 = mock[SchemaRDD] - val dataset3 = mock[SchemaRDD] - val dataset4 = mock[SchemaRDD] + val dataset0 = mock[DataFrame] + val dataset1 = mock[DataFrame] + val dataset2 = mock[DataFrame] + val dataset3 = mock[DataFrame] + val dataset4 = mock[DataFrame] when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) @@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite { val estimator = mock[Estimator[MyModel]] val pipeline = new Pipeline() .setStages(Array(estimator, estimator)) - val dataset = mock[SchemaRDD] + val dataset = mock[DataFrame] intercept[IllegalArgumentException] { pipeline.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e8030fef55b1d..b3d1bfcfbee0f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -20,50 +20,108 @@ package 
org.apache.spark.ml.classification import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ + private val eps: Double = 1e-5 override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) - dataset = sqlContext.createSchemaRDD( - sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + dataset = sqlContext.createDataFrame( + sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) } - test("logistic regression") { - val sqlContext = this.sqlContext - import sqlContext._ + test("logistic regression: default params") { val lr = new LogisticRegression + assert(lr.getLabelCol == "label") + assert(lr.getFeaturesCol == "features") + assert(lr.getPredictionCol == "prediction") + assert(lr.getRawPredictionCol == "rawPrediction") + assert(lr.getProbabilityCol == "probability") val model = lr.fit(dataset) model.transform(dataset) - .select('label, 'prediction) + .select("label", "probability", "prediction", "rawPrediction") .collect() + assert(model.getThreshold === 0.5) + assert(model.getFeaturesCol == "features") + assert(model.getPredictionCol == "prediction") + assert(model.getRawPredictionCol == "rawPrediction") + assert(model.getProbabilityCol == "probability") } test("logistic regression with setters") { - val sqlContext = this.sqlContext - import sqlContext._ + // Set params, train, and check as many params as we can. 
val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) + .setThreshold(0.6) + .setProbabilityCol("myProbability") val model = lr.fit(dataset) - model.transform(dataset, model.threshold -> 0.8) // overwrite threshold - .select('label, 'score, 'prediction) + assert(model.fittingParamMap.get(lr.maxIter) === Some(10)) + assert(model.fittingParamMap.get(lr.regParam) === Some(1.0)) + assert(model.fittingParamMap.get(lr.threshold) === Some(0.6)) + assert(model.getThreshold === 0.6) + + // Modify model params, and check that the params worked. + model.setThreshold(1.0) + val predAllZero = model.transform(dataset) + .select("prediction", "myProbability") .collect() + .map { case Row(pred: Double, prob: Vector) => pred } + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") + // Call transform with params, and check that the params worked. + val predNotAllZero = + model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") + .select("prediction", "myProb") + .collect() + .map { case Row(pred: Double, prob: Vector) => pred } + assert(predNotAllZero.exists(_ !== 0.0)) + + // Call fit() with new params, and check as many params as we can. 
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, + lr.probabilityCol -> "theProb") + assert(model2.fittingParamMap.get(lr.maxIter).get === 5) + assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) + assert(model2.fittingParamMap.get(lr.threshold).get === 0.4) + assert(model2.getThreshold === 0.4) + assert(model2.getProbabilityCol == "theProb") } - test("logistic regression fit and transform with varargs") { + test("logistic regression: Predictor, Classifier methods") { val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression - val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) - model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") - .select('label, 'probability, 'prediction) - .collect() + + val model = lr.fit(dataset) + assert(model.numClasses === 2) + + val threshold = model.getThreshold + val results = model.transform(dataset) + + // Compare rawPrediction with probability + results.select("rawPrediction", "probability").collect().map { + case Row(raw: Vector, prob: Vector) => + assert(raw.size === 2) + assert(prob.size === 2) + val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) + assert(prob(1) ~== probFromRaw1 relTol eps) + assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) + } + + // Compare prediction with probability + results.select("prediction", "probability").collect().map { + case Row(pred: Double, prob: Vector) => + val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 + assert(pred == predFromProb) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala new file mode 100644 index 0000000000000..bb86bafc0eb0a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import java.util.Random + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.scalatest.FunSuite + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} + +class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { + + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("LocalIndexEncoder") { + val random = new Random + for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) { + val encoder = new LocalIndexEncoder(numBlocks) + val maxLocalIndex = Int.MaxValue / numBlocks + val tests = Seq.fill(5)((random.nextInt(numBlocks), random.nextInt(maxLocalIndex))) ++ + Seq((0, 0), (numBlocks - 1, maxLocalIndex)) + tests.foreach { case (blockId, localIndex) => + val err = s"Failed with numBlocks=$numBlocks, blockId=$blockId, and localIndex=$localIndex." 
+ val encoded = encoder.encode(blockId, localIndex) + assert(encoder.blockId(encoded) === blockId, err) + assert(encoder.localIndex(encoded) === localIndex, err) + } + } + } + + test("normal equation construction with explict feedback") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 3.0f) + .add(Array(4.0f, 5.0f), 6.0f) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 2) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5") + // b = np.matrix("3; 6") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + + val ne1 = new NormalEquation(2) + .add(Array(7.0f, 8.0f), 9.0f) + ne0.merge(ne1) + assert(ne0.n === 3) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5; 7 8") + // b = np.matrix("3; 6; 9") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f), 2.0f) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + } + intercept[IllegalArgumentException] { + val ne2 = new NormalEquation(3) + ne0.merge(ne2) + } + + ne0.reset() + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + } + + test("normal equation construction with implicit feedback") { + val k = 2 + val alpha = 0.5 + val ne0 = new NormalEquation(k) + .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) + .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) + .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 0) // addImplicit doesn't increase the count. 
+ // NumPy code that computes the expected values: + // alpha = 0.5 + // A = np.matrix("-5 -4; -2 -1; 1 2") + // b = np.matrix("-3; 0; 3") + // b1 = b > 0 + // c = 1.0 + alpha * np.abs(b) + // C = np.diag(c.A1) + // I = np.eye(3) + // ata = A.transpose() * (C - I) * A + // atb = A.transpose() * C * b1 + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) + } + + test("CholeskySolver") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 4.0f) + .add(Array(1.0f, 3.0f), 9.0f) + .add(Array(1.0f, 4.0f), 16.0f) + val ne1 = new NormalEquation(k) + .merge(ne0) + + val chol = new CholeskySolver + val x0 = chol.solve(ne0, 0.0).map(_.toDouble) + // NumPy code that computes the expected solution: + // A = np.matrix("1 2; 1 3; 1 4") + // b = np.matrix("4; 9; 16") + // x0 = np.linalg.lstsq(A, b)[0] + assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) + + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + + val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + // NumPy code that computes the expected solution, where lambda is scaled by n: + // x1 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) + } + + test("RatingBlockBuilder") { + val emptyBuilder = new RatingBlockBuilder[Int]() + assert(emptyBuilder.size === 0) + val emptyBlock = emptyBuilder.build() + assert(emptyBlock.srcIds.isEmpty) + assert(emptyBlock.dstIds.isEmpty) + assert(emptyBlock.ratings.isEmpty) + + val builder0 = new RatingBlockBuilder() + .add(Rating(0, 1, 2.0f)) + .add(Rating(3, 4, 5.0f)) + assert(builder0.size === 2) + val builder1 = new RatingBlockBuilder() + .add(Rating(6, 7, 8.0f)) + .merge(builder0.build()) + assert(builder1.size === 3) + val block = builder1.build() + val ratings = Seq.tabulate(block.size) { i => 
(block.srcIds(i), block.dstIds(i), block.ratings(i)) + }.toSet + assert(ratings === Set((0, 1, 2.0f), (3, 4, 5.0f), (6, 7, 8.0f))) + } + + test("UncompressedInBlock") { + val encoder = new LocalIndexEncoder(10) + val uncompressed = new UncompressedInBlockBuilder[Int](encoder) + .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f)) + .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f)) + .build() + assert(uncompressed.length === 5) + val records = Seq.tabulate(uncompressed.length) { i => + val dstEncodedIndex = uncompressed.dstEncodedIndices(i) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + (uncompressed.srcIds(i), dstBlockId, dstLocalIndex, uncompressed.ratings(i)) + }.toSet + val expected = + Set((1, 0, 0, 1.0f), (0, 0, 1, 2.0f), (2, 0, 4, 3.0f), (3, 1, 2, 4.0f), (0, 1, 5, 5.0f)) + assert(records === expected) + + val compressed = uncompressed.compress() + assert(compressed.size === 5) + assert(compressed.srcIds.toSeq === Seq(0, 1, 2, 3)) + assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) + var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] + var i = 0 + while (i < compressed.srcIds.size) { + var j = compressed.dstPtrs(i) + while (j < compressed.dstPtrs(i + 1)) { + val dstEncodedIndex = compressed.dstEncodedIndices(j) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + decompressed += ((compressed.srcIds(i), dstBlockId, dstLocalIndex, compressed.ratings(j))) + j += 1 + } + i += 1 + } + assert(decompressed.toSet === expected) + } + + /** + * Generates an explicit feedback dataset for testing ALS. 
+ * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genExplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = { + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating[Int]] + val test = ArrayBuffer.empty[Rating[Int]] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val x = random.nextDouble() + if (x < totalFraction) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + logInfo(s"Generated an explicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates an implicit feedback dataset for testing ALS. + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genImplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = { + // The assumption of the implicit feedback model is that unobserved ratings are more likely to + // be negatives. 
+ val positiveFraction = 0.8 + val negativeFraction = 1.0 - positiveFraction + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating[Int]] + val test = ArrayBuffer.empty[Rating[Int]] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + val threshold = if (rating > 0) positiveFraction else negativeFraction + val observed = random.nextDouble() < threshold + if (observed) { + val x = random.nextDouble() + if (x < totalFraction) { + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + } + logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates random user/item factors, with i.i.d. values drawn from U(a, b). + * @param size number of users/items + * @param rank number of features + * @param random random number generator + * @param a min value of the support (default: -1) + * @param b max value of the support (default: 1) + * @return a sequence of (ID, factors) pairs + */ + private def genFactors( + size: Int, + rank: Int, + random: Random, + a: Float = -1.0f, + b: Float = 1.0f): Seq[(Int, Array[Float])] = { + require(size > 0 && size < Int.MaxValue / 3) + require(b > a) + val ids = mutable.Set.empty[Int] + while (ids.size < size) { + ids += random.nextInt() + } + val width = b - a + ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width))) + } + + /** + * Test ALS using the given training/test splits and parameters. 
+ * @param training training dataset + * @param test test dataset + * @param rank rank of the matrix factorization + * @param maxIter max number of iterations + * @param regParam regularization constant + * @param implicitPrefs whether to use implicit preference + * @param numUserBlocks number of user blocks + * @param numItemBlocks number of item blocks + * @param targetRMSE target test RMSE + */ + def testALS( + training: RDD[Rating[Int]], + test: RDD[Rating[Int]], + rank: Int, + maxIter: Int, + regParam: Double, + implicitPrefs: Boolean = false, + numUserBlocks: Int = 2, + numItemBlocks: Int = 3, + targetRMSE: Double = 0.05): Unit = { + val sqlContext = this.sqlContext + import sqlContext.implicits._ + val als = new ALS() + .setRank(rank) + .setRegParam(regParam) + .setImplicitPrefs(implicitPrefs) + .setNumUserBlocks(numUserBlocks) + .setNumItemBlocks(numItemBlocks) + val alpha = als.getAlpha + val model = als.fit(training.toDF()) + val predictions = model.transform(test.toDF()) + .select("rating", "prediction") + .map { case Row(rating: Float, prediction: Float) => + (rating.toDouble, prediction.toDouble) + } + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE + // with the confidence scores as weights. 
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val mse = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + }.mean() + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) + } + + test("exact rank-1 matrix") { + val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1) + testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001) + testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001) + } + + test("approximate rank-1 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 1, noiseStd = 0.01) + testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02) + testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02) + } + + test("approximate rank-2 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03) + testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03) + } + + test("different block settings") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) { + testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03, + numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks) + } + } + + test("more blocks than ratings") { + 
val (training, test) = + genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002, + numItemBlocks = 5, numUserBlocks = 5) + } + + test("implicit feedback") { + val (training, test) = + genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true, + targetRMSE = 0.3) + } + + test("using generic ID types") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + + val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating)) + val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4) + assert(longUserFactors.first()._1.getClass === classOf[Long]) + + val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating)) + val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4) + assert(strUserFactors.first()._1.getClass === classOf[String]) + } + + test("nonnegative constraint") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true) + def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = { + factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _) + } + assert(isNonnegative(userFactors)) + assert(isNonnegative(itemFactors)) + // TODO: Validate the solution. 
+ } + + test("als partitioner is a projection") { + for (p <- Seq(1, 10, 100, 1000)) { + val part = new ALSPartitioner(p) + var k = 0 + while (k < p) { + assert(k === part.getPartition(k)) + assert(k === part.getPartition(k.toLong)) + k += 1 + } + } + } + + test("partitioner in returned factors") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + val (userFactors, itemFactors) = ALS.train( + ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4) + for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) { + assert(factors.partitioner.isDefined, s"$tpe factors should have partitioner.") + val part = factors.partitioner.get + factors.mapPartitionsWithIndex { (idx, items) => + items.foreach { case (id, _) => + if (part.getPartition(id) != idx) { + throw new SparkException(s"$tpe with ID $id should not be in partition $idx.") + } + } + Iterator.empty + }.count() + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala new file mode 100644 index 0000000000000..bbb44c3e2dfc2 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, SQLContext} + +class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + dataset = sqlContext.createDataFrame( + sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) + } + + test("linear regression: default params") { + val lr = new LinearRegression + assert(lr.getLabelCol == "label") + val model = lr.fit(dataset) + model.transform(dataset) + .select("label", "prediction") + .collect() + // Check defaults + assert(model.getFeaturesCol == "features") + assert(model.getPredictionCol == "prediction") + } + + test("linear regression with setters") { + // Set params, train, and check as many as we can. + val lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(1.0) + val model = lr.fit(dataset) + assert(model.fittingParamMap.get(lr.maxIter).get === 10) + assert(model.fittingParamMap.get(lr.regParam).get === 1.0) + + // Call fit() with new params, and check as many as we can. 
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred") + assert(model2.fittingParamMap.get(lr.maxIter).get === 5) + assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) + assert(model2.getPredictionCol == "thePred") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 41cc13da4d5b1..761ea821ef7c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -23,16 +23,16 @@ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() val sqlContext = new SQLContext(sc) - dataset = sqlContext.createSchemaRDD( + dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 94b0e00f37267..d2b40f2cae020 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -17,16 +17,19 @@ package org.apache.spark.mllib.classification -import scala.util.Random import scala.collection.JavaConversions._ +import 
scala.util.Random +import scala.util.control.Breaks._ import org.scalatest.FunSuite import org.scalatest.Matchers -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + object LogisticRegressionSuite { @@ -55,8 +58,116 @@ object LogisticRegressionSuite { val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i))))) testData } + + /** + * Generates `k` classes multinomial synthetic logistic input in `n` dimensional space given the + * model weights and mean/variance of the features. The synthetic data will be drawn from + * the probability distribution constructed by weights using the following formula. + * + * P(y = 0 | x) = 1 / norm + * P(y = 1 | x) = exp(x * w_1) / norm + * P(y = 2 | x) = exp(x * w_2) / norm + * ... + * P(y = k-1 | x) = exp(x * w_{k-1}) / norm + * where norm = 1 + exp(x * w_1) + exp(x * w_2) + ... + exp(x * w_{k-1}) + * + * @param weights matrix is flattened into a vector; as a result, the dimension of weights vector + * will be (k - 1) * (n + 1) if `addIntercept == true`, and + * if `addIntercept != true`, the dimension will be (k - 1) * n. + * @param xMean the mean of the generated features. Often, if the features are not properly + * standardized, a poorly implemented algorithm will have difficulty + * converging. + * @param xVariance the variance of the generated features. + * @param addIntercept whether to add intercept. + * @param nPoints the number of instances of generated data. + * @param seed the seed for the random generator. It is fixed for consistent test results. 
+ */ + def generateMultinomialLogisticInput( + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + addIntercept: Boolean, + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + + val xDim = xMean.size + val xWithInterceptsDim = if (addIntercept) xDim + 1 else xDim + val nClasses = weights.size / xWithInterceptsDim + 1 + + val x = Array.fill[Vector](nPoints)(Vectors.dense(Array.fill[Double](xDim)(rnd.nextGaussian()))) + + x.foreach(vector => { + // Standardize the features in place. This doesn't work if `vector` is a sparse vector. + val vectorArray = vector.toArray + var i = 0 + while (i < vectorArray.size) { + vectorArray(i) = vectorArray(i) * math.sqrt(xVariance(i)) + xMean(i) + i += 1 + } + }) + + val y = (0 until nPoints).map { idx => + val xArray = x(idx).toArray + val margins = Array.ofDim[Double](nClasses) + val probs = Array.ofDim[Double](nClasses) + + for (i <- 0 until nClasses - 1) { + for (j <- 0 until xDim) margins(i + 1) += weights(i * xWithInterceptsDim + j) * xArray(j) + if (addIntercept) margins(i + 1) += weights((i + 1) * xWithInterceptsDim - 1) + } + // Preventing the overflow when we compute the probability + val maxMargin = margins.max + if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin + + // Computing the probabilities for each class from the margins. + val norm = { + var temp = 0.0 + for (i <- 0 until nClasses) { + probs(i) = math.exp(margins(i)) + temp += probs(i) + } + temp + } + for (i <- 0 until nClasses) probs(i) /= norm + + // Compute the cumulative probability so we can generate a random number and assign a label. 
+ for (i <- 1 until nClasses) probs(i) += probs(i - 1) + val p = rnd.nextDouble() + var y = 0 + breakable { + for (i <- 0 until nClasses) { + if (p < probs(i)) { + y = i + break + } + } + } + y + } + + val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) + testData + } + + /** Binary labels, 3 features */ + private val binaryModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5, numFeatures = 3, numClasses = 2) + + /** 3 classes, 2 features */ + private val multiclassModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) + + private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = { + assert(a.weights == b.weights) + assert(a.intercept == b.intercept) + assert(a.numClasses == b.numClasses) + assert(a.numFeatures == b.numFeatures) + assert(a.getThreshold == b.getThreshold) + } } + class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], @@ -285,6 +396,138 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1) } + test("multinomial logistic regression with LBFGS") { + val nPoints = 10000 + + /** + * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2. + * As a result, we are actually drawing samples from probability distribution of built model. 
+ */ + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(3) + lr.optimizer.setConvergenceTol(1E-15).setNumIterations(200) + + val model = lr.run(testRDD) + + /** + * The following are the instructions to reproduce the model using R's glmnet package. + * + * First of all, use the following Scala code to save the data into `path`. + * + * testRDD.map(x => x.label+ ", " + x.features(0) + ", " + x.features(1) + ", " + + * x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + * + * Then use the following R code to load the data and train the model using glmnet package. + * + * library("glmnet") + * data <- read.csv("path", header=FALSE) + * label = factor(data$V1) + * features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + * weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0)) + * + * The model weights of multinomial logistic regression in R have `K` set of linear predictors + * for `K` classes classification problem; however, only `K-1` set is required if the first + * outcome is chosen as a "pivot", and the other `K-1` outcomes are separately regressed against + * the pivot outcome. This can be done by subtracting the first weights from those `K-1` set + * weights. 
The mathematical discussion and proof can be found here: + * http://en.wikipedia.org/wiki/Multinomial_logistic_regression + * + * weights1 = weights$`1` - weights$`0` + * weights2 = weights$`2` - weights$`0` + * + * > weights1 + * 5 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * 2.6228269 + * data.V2 -0.5837166 + * data.V3 0.9285260 + * data.V4 -0.3783612 + * data.V5 -0.8123411 + * > weights2 + * 5 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * 4.11197445 + * data.V2 -0.16918650 + * data.V3 -0.81104784 + * data.V4 -0.06463799 + * data.V5 -0.29198337 + */ + + val weightsR = Vectors.dense(Array( + -0.5837166, 0.9285260, -0.3783612, -0.8123411, 2.6228269, + -0.1691865, -0.811048, -0.0646380, -0.2919834, 4.1119745)) + + assert(model.weights ~== weightsR relTol 0.05) + + val validationData = LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // The validation accuracy is not good since this model (even the original weights) doesn't have + // very steep curve in logistic function so that when we draw samples from distribution, it's + // very easy to assign to another labels. However, this prediction result is consistent to R. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.47) + + } + + test("model save/load: binary classification") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = LogisticRegressionSuite.binaryModel + + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + + // Save model with threshold. 
+ try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model save/load: multiclass classification") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = LogisticRegressionSuite.multiclassModel + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 29871215f9deb..93acb424dd5a4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.Utils + object NaiveBayesSuite { @@ -73,6 +75,10 @@ object NaiveBayesSuite { LabeledPoint(y, Vectors.dense(xi)) } } + + /** Binary labels, 3 features */ + private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), + theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayesModels.Bernoulli) } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -138,7 +144,6 @@ class NaiveBayesSuite extends FunSuite with 
MLlibTestSparkContext { Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 ).map(_.map(math.log)) - val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 45, NaiveBayesModels.Bernoulli) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -182,6 +187,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { NaiveBayes.train(sc.makeRDD(nan, 2)) } } + + test("model save/load") { + val model = NaiveBayesSuite.binaryModel + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index a2de7fbd41383..6de098b383ba3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.Utils object SVMSuite { @@ -56,6 +57,9 @@ object SVMSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + /** Binary labels, 3 features */ + private val binaryModel = new SVMModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) + } class SVMSuite extends FunSuite with MLlibTestSparkContext { @@ -191,6 +195,38 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { // 
Turning off data validation should not throw an exception new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } + + test("model save/load") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = SVMSuite.binaryModel + + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = SVMModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + assert(sameModel.getThreshold.isEmpty) + } finally { + Utils.deleteRecursively(tempDir) + } + + // Save model with threshold. + try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel2 = SVMModel.load(sc, path) + assert(model.getThreshold.get == sameModel2.getThreshold.get) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala new file mode 100644 index 0000000000000..8b3e6e5ce9249 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.TestSuiteBase + +class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { + + // use longer wait time to ensure job completion + override def maxWaitTimeMillis = 30000 + + // Test if we can accurately learn B for Y = logistic(BX) on streaming data + test("parameter accuracy") { + + val nPoints = 100 + val B = 1.5 + + // create model + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data + val numBatches = 20 + val input = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1)) + } + + // apply model training to input stream + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check accuracy of final parameter estimates + assert(model.latestModel().weights(0) ~== B relTol 0.1) + + } + + // Test that parameter estimates improve when learning Y = logistic(BX) on streaming data + test("parameter convergence") { + + val B = 1.5 + val nPoints = 100 + + // create 
model + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data + val numBatches = 20 + val input = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1)) + } + + // create buffer to store intermediate fits + val history = new ArrayBuffer[Double](numBatches) + + // apply model training to input stream, storing the intermediate results + // (we add a count to ensure the result is a DStream) + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B))) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // compute change in error + val deltas = history.drop(1).zip(history.dropRight(1)) + // check error stability (it always either shrinks, or increases with small tol) + assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) + // check that error shrunk on at least 2 batches + assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1) + } + + // Test predictions on a stream + test("predictions") { + + val B = 1.5 + val nPoints = 100 + + // create model initialized with true weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(1.5)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1)) + } + + // apply model predictions to test stream + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + // collect the output as (true, estimated) tuples + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, 
numBatches, numBatches) + + // check that at least 60% of predictions are correct on all batches + val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) + + assert(errors.forall(x => x <= 0.4)) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala deleted file mode 100644 index 9da5495741a80..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.clustering - -import org.scalatest.FunSuite - -import org.apache.spark.mllib.linalg.{Vectors, Matrices} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ - -class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext { - test("single cluster") { - val data = sc.parallelize(Array( - Vectors.dense(6.0, 9.0), - Vectors.dense(5.0, 10.0), - Vectors.dense(4.0, 11.0) - )) - - // expectations - val Ew = 1.0 - val Emu = Vectors.dense(5.0, 10.0) - val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0)) - - val seeds = Array(314589, 29032897, 50181, 494821, 4660) - seeds.foreach { seed => - val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data) - assert(gmm.weight(0) ~== Ew absTol 1E-5) - assert(gmm.mu(0) ~== Emu absTol 1E-5) - assert(gmm.sigma(0) ~== Esigma absTol 1E-5) - } - } - - test("two clusters") { - val data = sc.parallelize(Array( - Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), - Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), - Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), - Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), - Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) - )) - - // we set an initial gaussian to induce expected results - val initialGmm = new GaussianMixtureModel( - Array(0.5, 0.5), - Array(Vectors.dense(-1.0), Vectors.dense(1.0)), - Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0))) - ) - - val Ew = Array(1.0 / 3.0, 2.0 / 3.0) - val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) - val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) - - val gmm = new GaussianMixtureEM() - .setK(2) - .setInitialModel(initialGmm) - .run(data) - - assert(gmm.weight(0) ~== Ew(0) absTol 1E-3) - assert(gmm.weight(1) 
~== Ew(1) absTol 1E-3) - assert(gmm.mu(0) ~== Emu(0) absTol 1E-3) - assert(gmm.mu(1) ~== Emu(1) absTol 1E-3) - assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3) - assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala new file mode 100644 index 0000000000000..1b46a4012d731 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vectors, Matrices} +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { + test("single cluster") { + val data = sc.parallelize(Array( + Vectors.dense(6.0, 9.0), + Vectors.dense(5.0, 10.0), + Vectors.dense(4.0, 11.0) + )) + + // expectations + val Ew = 1.0 + val Emu = Vectors.dense(5.0, 10.0) + val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0)) + + val seeds = Array(314589, 29032897, 50181, 494821, 4660) + seeds.foreach { seed => + val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data) + assert(gmm.weights(0) ~== Ew absTol 1E-5) + assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) + assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) + } + + } + + test("two clusters") { + val data = sc.parallelize(Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + )) + + // we set an initial gaussian to induce expected results + val initialGmm = new GaussianMixtureModel( + Array(0.5, 0.5), + Array( + new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))), + new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) + ) + ) + + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) + val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) + val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) + + val gmm = new 
GaussianMixture() + .setK(2) + .setInitialModel(initialGmm) + .run(data) + + assert(gmm.weights(0) ~== Ew(0) absTol 1E-3) + assert(gmm.weights(1) ~== Ew(1) absTol 1E-3) + assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3) + assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3) + assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) + assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) + } + + test("single cluster with sparse data") { + val data = sc.parallelize(Array( + Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)), + Vectors.sparse(3, Array(0, 2), Array(2.0, 4.0)), + Vectors.sparse(3, Array(1), Array(6.0)) + )) + + val Ew = 1.0 + val Emu = Vectors.dense(2.0, 2.0, 2.0) + val Esigma = Matrices.dense(3, 3, + Array(8.0 / 3.0, -4.0, 4.0 / 3.0, -4.0, 8.0, -4.0, 4.0 / 3.0, -4.0, 8.0 / 3.0) + ) + + val seeds = Array(42, 1994, 27, 11, 0) + seeds.foreach { seed => + val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data) + assert(gmm.weights(0) ~== Ew absTol 1E-5) + assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) + assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) + } + } + + test("two clusters with sparse data") { + val data = sc.parallelize(Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + )) + + val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray)) + // we set an initial gaussian to induce expected results + val initialGmm = new GaussianMixtureModel( + Array(0.5, 0.5), + Array( + new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))), + new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) + ) + ) + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) + val Emu =
Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) + val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) + + val sparseGMM = new GaussianMixture() + .setK(2) + .setInitialModel(initialGmm) + .run(sparseData) // train on the sparse copy built above, not the dense `data` + + assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3) + assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3) + assert(sparseGMM.gaussians(0).mu ~== Emu(0) absTol 1E-3) + assert(sparseGMM.gaussians(1).mu ~== Emu(1) absTol 1E-3) + assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) + assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 9ebef8466c831..caee5917000aa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { assert(model.clusterCenters.size === 3) } + test("deterministic initialization") { + // Create a large-ish set of points for clustering + val points = List.tabulate(1000)(n => Vectors.dense(n, n)) + val rdd = sc.parallelize(points, 3) + + for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) { + // Train two models with the same seed and compare their cluster centers + val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, + initializationMode = initMode, seed = 42) + val centers1 = model1.clusterCenters + + val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, + initializationMode = initMode, seed = 42) + val centers2 = model2.clusterCenters + + centers1.zip(centers2).foreach { case (c1, c2) => + assert(c1 ~== c2 absTol 1E-14) + } + } + } + test("single cluster with big dataset") { val smallData = Array( Vectors.dense(1.0, 2.0, 6.0), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala new file mode 100644 index 0000000000000..302d751eb8a94 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class LDASuite extends FunSuite with MLlibTestSparkContext { + + import LDASuite._ + + test("LocalLDAModel") { + val model = new LocalLDAModel(tinyTopics) + + // Check: basic parameters + assert(model.k === tinyK) + assert(model.vocabSize === tinyVocabSize) + assert(model.topicsMatrix === tinyTopics) + + // Check: describeTopics() with all terms + val fullTopicSummary = model.describeTopics() + assert(fullTopicSummary.size === tinyK) + fullTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms) + assert(algTermWeights === termWeights) + } + + // Check: describeTopics() with some terms + val smallNumTerms = 3 + val smallTopicSummary = model.describeTopics(maxTermsPerTopic = smallNumTerms) + smallTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms.slice(0, smallNumTerms)) + assert(algTermWeights === termWeights.slice(0, smallNumTerms)) + } + } + + test("running and DistributedLDAModel") { + val k = 3 + val topicSmoothing = 1.2 + val termSmoothing = 1.2 + + // Train a model + val lda = new LDA() + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345) + val corpus = sc.parallelize(tinyCorpus, 2) + + val model: DistributedLDAModel = lda.run(corpus) + + // Check: basic parameters + val localModel = model.toLocal + assert(model.k === k) + assert(localModel.k === k) + assert(model.vocabSize === tinyVocabSize) + assert(localModel.vocabSize === tinyVocabSize) + assert(model.topicsMatrix === localModel.topicsMatrix) + + // Check: topic summaries + // The odd decimal 
formatting and sorting is a hack to do a robust comparison. + val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => + // cut values to 3 digits after the decimal place + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } + }.sortBy(_.mkString("")) + val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + // cut values to 3 digits after the decimal place + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } + }.sortBy(_.mkString("")) + roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => + assert(t1 === t2) + } + + // Check: per-doc topic distributions + val topicDistributions = model.topicDistributions.collect() + // Ensure all documents are covered. + assert(topicDistributions.size === tinyCorpus.size) + assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet) + // Ensure we have proper distributions + topicDistributions.foreach { case (docId, topicDistribution) => + assert(topicDistribution.size === tinyK) + assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5) + } + + // Check: log probabilities + assert(model.logLikelihood < 0.0) + assert(model.logPrior < 0.0) + } + + test("vertex indexing") { + // Check vertex ID indexing and conversions. 
+ val docIds = Array(0, 1, 2) + val docVertexIds = docIds + val termIds = Array(0, 1, 2) + val termVertexIds = Array(-1, -2, -3) + assert(docVertexIds.forall(i => !LDA.isTermVertex((i.toLong, 0)))) + assert(termIds.map(LDA.term2index) === termVertexIds) + assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds) + assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0)))) + } +} + +private[clustering] object LDASuite { + + def tinyK: Int = 3 + def tinyVocabSize: Int = 5 + def tinyTopicsAsArray: Array[Array[Double]] = Array( + Array[Double](0.1, 0.2, 0.3, 0.4, 0.0), // topic 0 + Array[Double](0.5, 0.05, 0.05, 0.1, 0.3), // topic 1 + Array[Double](0.2, 0.2, 0.05, 0.05, 0.5) // topic 2 + ) + def tinyTopics: Matrix = new DenseMatrix(numRows = tinyVocabSize, numCols = tinyK, + values = tinyTopicsAsArray.fold(Array.empty[Double])(_ ++ _)) + def tinyTopicDescription: Array[(Array[Int], Array[Double])] = tinyTopicsAsArray.map { topic => + val (termWeights, terms) = topic.zipWithIndex.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } + + def tinyCorpus = Array( + Vectors.dense(1, 3, 0, 2, 8), + Vectors.dense(0, 2, 1, 0, 4), + Vectors.dense(2, 3, 12, 3, 1), + Vectors.dense(0, 3, 1, 9, 8), + Vectors.dense(1, 1, 4, 2, 6) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala new file mode 100644 index 0000000000000..6315c03a700f1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { + + import org.apache.spark.mllib.clustering.PowerIterationClustering._ + + test("power iteration clustering") { + /* + We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for + edge (3, 4). + + 15-14 -13 -12 + | | + 4 . 
3 - 2 11 + | | x | | + 5 0 - 1 10 + | | + 6 - 7 - 8 - 9 + */ + + val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), + (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge + (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), + (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) + val model = new PowerIterationClustering() + .setK(2) + .run(sc.parallelize(similarities, 2)) + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + model.assignments.collect().foreach { a => + predictions(a.cluster) += a.id + } + assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + + val model2 = new PowerIterationClustering() + .setK(2) + .setInitializationMode("degree") + .run(sc.parallelize(similarities, 2)) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + model2.assignments.collect().foreach { a => + predictions2(a.cluster) += a.id + } + assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + } + + test("normalize and powerIter") { + /* + Test normalize() with the following graph: + + 0 - 3 + | \ | + 1 - 2 + + The affinity matrix (A) is + + 0 1 1 1 + 1 0 1 0 + 1 1 0 1 + 1 0 1 0 + + D is diag(3, 2, 3, 2) and hence W is + + 0 1/3 1/3 1/3 + 1/2 0 1/2 0 + 1/3 1/3 0 1/3 + 1/2 0 1/2 0 + */ + val similarities = Seq[(Long, Long, Double)]( + (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0)) + val expected = Array( + Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0), + Array(1.0/2.0, 0.0, 1.0/2.0, 0.0), + Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0), + Array(1.0/2.0, 0.0, 1.0/2.0, 0.0)) + val w = normalize(sc.parallelize(similarities, 2)) + w.edges.collect().foreach { case Edge(i, j, x) => + assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14) + } + val v0 = sc.parallelize(Seq[(Long, Double)]((0, 0.1), (1, 0.2), (2, 0.3), (3, 0.4)), 2) + val w0 = Graph(v0, w.edges) + val v1 = powerIter(w0, maxIterations = 1).collect() + val u = 
Array(0.3, 0.2, 0.7/3.0, 0.2) + val norm = u.sum + val u1 = u.map(x => x / norm) + v1.foreach { case (i, x) => + assert(x ~== u1(i.toInt) absTol 1e-14) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala new file mode 100644 index 0000000000000..747f5914598ec --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext { + + /* + * Contingency tables + * feature0 = {8.0, 0.0} + * class 0 1 2 + * 8.0||1|0|1| + * 0.0||0|2|0| + * + * feature1 = {7.0, 9.0} + * class 0 1 2 + * 7.0||1|0|0| + * 9.0||0|2|1| + * + * feature2 = {0.0, 6.0, 8.0, 5.0} + * class 0 1 2 + * 0.0||1|0|0| + * 6.0||0|1|0| + * 8.0||0|1|0| + * 5.0||0|0|1| + * + * Use chi-squared calculator from Internet + */ + + test("ChiSqSelector transform test (sparse & dense vector)") { + val labeledDiscreteData = sc.parallelize( + Seq(LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), + LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), + LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) + val preFilteredData = + Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), + LabeledPoint(1.0, Vectors.dense(Array(6.0))), + LabeledPoint(1.0, Vectors.dense(Array(8.0))), + LabeledPoint(2.0, Vectors.dense(Array(5.0)))) + val model = new ChiSqSelector(1).fit(labeledDiscreteData) + val filteredData = labeledDiscreteData.map { lp => + LabeledPoint(lp.label, model.transform(lp.features)) + }.collect().toSet + assert(filteredData == preFilteredData) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 4c93c0ca4f86c..7f94564b2a3ae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -22,29 +22,114 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, 
Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { + // When the input data is all constant, the variance is zero. The standardization against + // zero variance is not well-defined, but we decide to just set it into zero here. + val constantData = Array( + Vectors.dense(2.0), + Vectors.dense(2.0), + Vectors.dense(2.0) + ) + + val sparseData = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), + Vectors.sparse(3, Seq((1, -5.1))), + Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), + Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), + Vectors.sparse(3, Seq((1, 1.9))) + ) + + val denseData = Array( + Vectors.dense(-2.0, 2.3, 0), + Vectors.dense(0.0, -1.0, -3.0), + Vectors.dense(0.0, -5.1, 0.0), + Vectors.dense(3.8, 0.0, 1.9), + Vectors.dense(1.7, -0.6, 0.0), + Vectors.dense(0.0, 1.9, 0.0) + ) + private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) } + test("Standardization with dense input when means and stds are provided") { + + val dataRDD = sc.parallelize(denseData, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) + + val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean) + val equivalentModel2 = new StandardScalerModel(model2.std, 
model2.mean, true, false) + val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true) + + val data1 = denseData.map(equivalentModel1.transform) + val data2 = denseData.map(equivalentModel2.transform) + val data3 = denseData.map(equivalentModel3.transform) + + val data1RDD = equivalentModel1.transform(dataRDD) + val data2RDD = equivalentModel2.transform(dataRDD) + val data3RDD = equivalentModel3.transform(dataRDD) + + val summary = computeSummary(dataRDD) + val summary1 = computeSummary(data1RDD) + val summary2 = computeSummary(data2RDD) + val summary3 = computeSummary(data3RDD) + + assert((denseData, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((denseData, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((denseData, data3, data3RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== 
Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary3.variance ~== summary.variance absTol 1E-5) + + assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5) + assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5) + assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5) + assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5) + assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5) + } + test("Standardization with dense input") { - val data = Array( - Vectors.dense(-2.0, 2.3, 0), - Vectors.dense(0.0, -1.0, -3.0), - Vectors.dense(0.0, -5.1, 0.0), - Vectors.dense(3.8, 0.0, 1.9), - Vectors.dense(1.7, -0.6, 0.0), - Vectors.dense(0.0, 1.9, 0.0) - ) - val dataRDD = sc.parallelize(data, 3) + val dataRDD = sc.parallelize(denseData, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -54,9 +139,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val model2 = standardizer2.fit(dataRDD) val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(model1.transform) - val data2 = data.map(model2.transform) - val data3 = data.map(model3.transform) + val data1 = denseData.map(model1.transform) + val data2 = denseData.map(model2.transform) + val data3 = denseData.map(model3.transform) val data1RDD = model1.transform(dataRDD) val data2RDD = model2.transform(dataRDD) @@ -67,19 +152,19 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val summary2 = computeSummary(data2RDD) val summary3 = computeSummary(data3RDD) - assert((data, data1, data1RDD.collect()).zipped.forall { + assert((denseData, data1, data1RDD.collect()).zipped.forall { case (v1: DenseVector, v2: 
DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false }, "The vector type should be preserved after standardization.") - assert((data, data2, data2RDD.collect()).zipped.forall { + assert((denseData, data2, data2RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false }, "The vector type should be preserved after standardization.") - assert((data, data3, data3RDD.collect()).zipped.forall { + assert((denseData, data3, data3RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false @@ -107,17 +192,58 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { } + test("Standardization with sparse input when means and stds are provided") { + + val dataRDD = sc.parallelize(sparseData, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) + + val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean) + val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false) + val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true) + + val data2 = sparseData.map(equivalentModel2.transform) + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + sparseData.map(equivalentModel1.transform) + } + } + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + sparseData.map(equivalentModel3.transform) + } + } + + 
val data2RDD = equivalentModel2.transform(dataRDD) + + val summary = computeSummary(data2RDD) + + assert((sparseData, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) + } + test("Standardization with sparse input") { - val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), - Vectors.sparse(3, Seq((1, -5.1))), - Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), - Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), - Vectors.sparse(3, Seq((1, 1.9))) - ) - val dataRDD = sc.parallelize(data, 3) + val dataRDD = sc.parallelize(sparseData, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -127,25 +253,26 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val model2 = standardizer2.fit(dataRDD) val model3 = standardizer3.fit(dataRDD) - val data2 = data.map(model2.transform) + val data2 = sparseData.map(model2.transform) withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(model1.transform) + sparseData.map(model1.transform) } } withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(model3.transform) + sparseData.map(model3.transform) } } val data2RDD = model2.transform(dataRDD) - val 
summary2 = computeSummary(data2RDD) - assert((data, data2, data2RDD.collect()).zipped.forall { + val summary = computeSummary(data2RDD) + + assert((sparseData, data2, data2RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false @@ -153,23 +280,44 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) - assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) } + test("Standardization with constant input when means and stds are provided") { + + val dataRDD = sc.parallelize(constantData, 2) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler(withMean = true, withStd = false) + val standardizer3 = new StandardScaler(withMean = false, withStd = true) + + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) + + val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean) + val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false) + val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true) + + val data1 = constantData.map(equivalentModel1.transform) + val data2 = constantData.map(equivalentModel2.transform) + val data3 = constantData.map(equivalentModel3.transform) + + assert(data1.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result 
should be 0.0") + assert(data2.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data3.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + } + test("Standardization with constant input") { - // When the input data is all constant, the variance is zero. The standardization against - // zero variance is not well-defined, but we decide to just set it into zero here. - val data = Array( - Vectors.dense(2.0), - Vectors.dense(2.0), - Vectors.dense(2.0) - ) - val dataRDD = sc.parallelize(data, 2) + val dataRDD = sc.parallelize(constantData, 2) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler(withMean = true, withStd = false) @@ -179,9 +327,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val model2 = standardizer2.fit(dataRDD) val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(model1.transform) - val data2 = data.map(model2.transform) - val data3 = data.map(model3.transform) + val data1 = constantData.map(model1.transform) + val data2 = constantData.map(model2.transform) + val data3 = constantData.map(model3.transform) assert(data1.forall(_.toArray.forall(_ == 0.0)), "The variance is zero, so the transformed result should be 0.0") @@ -191,4 +339,29 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { "The variance is zero, so the transformed result should be 0.0") } + test("StandardScalerModel argument nulls are properly handled") { + + withClue("model needs at least one of std or mean vectors") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(null, null) + } + } + withClue("model needs std to set withStd to true") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(null, Vectors.dense(0.0)) + model.setWithStd(true) + } + } + withClue("model needs mean to set withMean to true") { 
+ intercept[IllegalArgumentException] { + val model = new StandardScalerModel(Vectors.dense(0.0), null) + model.setWithMean(true) + } + } + withClue("model needs std and mean vectors to be equal size when both are provided") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0)) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala new file mode 100644 index 0000000000000..bd5b9cc3afa10 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.mllib.fpm + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { + + + test("FP-Growth using String type") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run(rdd) + assert(model6.freqItemsets.count() === 0) + + val model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + val expected = Set( + (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), + (Set("r"), 3L), + (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L), + (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L), + (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L), + (Set("t", "y", "x"), 3L), + (Set("t", "y", "x", "z"), 3L)) + assert(freqItemsets3.toSet === expected) + + val model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run(rdd) + assert(model2.freqItemsets.count() === 54) + + val model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run(rdd) + assert(model1.freqItemsets.count() === 625) + } + + test("FP-Growth using Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toArray) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run(rdd) + assert(model6.freqItemsets.count() === 0) + + val model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + 
assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, + "frequent itemsets should use primitive arrays") + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + val expected = Set( + (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L), + (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L), + (Set(2, 4), 4L), (Set(1, 2, 3), 4L)) + assert(freqItemsets3.toSet === expected) + + val model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run(rdd) + assert(model2.freqItemsets.count() === 15) + + val model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run(rdd) + assert(model1.freqItemsets.count() === 65) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala new file mode 100644 index 0000000000000..04017f67c311d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.fpm + +import scala.language.existentials + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class FPTreeSuite extends FunSuite with MLlibTestSparkContext { + + test("add transaction") { + val tree = new FPTree[String] + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "y")) + .add(Seq("b")) + + assert(tree.root.children.size == 2) + assert(tree.root.children.contains("a")) + assert(tree.root.children("a").item.equals("a")) + assert(tree.root.children("a").count == 2) + assert(tree.root.children.contains("b")) + assert(tree.root.children("b").item.equals("b")) + assert(tree.root.children("b").count == 1) + var child = tree.root.children("a") + assert(child.children.size == 1) + assert(child.children.contains("b")) + assert(child.children("b").item.equals("b")) + assert(child.children("b").count == 2) + child = child.children("b") + assert(child.children.size == 2) + assert(child.children.contains("c")) + assert(child.children.contains("y")) + assert(child.children("c").item.equals("c")) + assert(child.children("y").item.equals("y")) + assert(child.children("c").count == 1) + assert(child.children("y").count == 1) + } + + test("merge tree") { + val tree1 = new FPTree[String] + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "y")) + .add(Seq("b")) + + val tree2 = new FPTree[String] + .add(Seq("a", "b")) + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "c", "d")) + .add(Seq("a", "x")) + .add(Seq("a", "x", "y")) + .add(Seq("c", "n")) + .add(Seq("c", "m")) + + val tree3 = tree1.merge(tree2) + + assert(tree3.root.children.size == 3) + assert(tree3.root.children("a").count == 7) + assert(tree3.root.children("b").count == 1) + assert(tree3.root.children("c").count == 2) + val child1 = tree3.root.children("a") + assert(child1.children.size == 2) + assert(child1.children("b").count == 5) + assert(child1.children("x").count == 2) + val child2 = child1.children("b") + assert(child2.children.size == 
2) + assert(child2.children("y").count == 1) + assert(child2.children("c").count == 3) + val child3 = child2.children("c") + assert(child3.children.size == 1) + assert(child3.children("d").count == 1) + val child4 = child1.children("x") + assert(child4.children.size == 1) + assert(child4.children("y").count == 1) + val child5 = tree3.root.children("c") + assert(child5.children.size == 2) + assert(child5.children("n").count == 1) + assert(child5.children("m").count == 1) + } + + test("extract freq itemsets") { + val tree = new FPTree[String] + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "y")) + .add(Seq("a", "b")) + .add(Seq("a")) + .add(Seq("b")) + .add(Seq("b", "n")) + + val freqItemsets = tree.extract(3L).map { case (items, count) => + (items.toSet, count) + }.toSet + val expected = Set( + (Set("a"), 4L), + (Set("b"), 5L), + (Set("a", "b"), 3L)) + assert(freqItemsets === expected) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala new file mode 100644 index 0000000000000..699f009f0f2ec --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.scalatest.FunSuite + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.SparkContext +import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { + + import PeriodicGraphCheckpointerSuite._ + + // TODO: Do I need to call count() on the graphs' RDDs? + + test("Persisting") { + var graphsToCheck = Seq.empty[GraphToCheck] + + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer(graph1, 10) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.updateGraph(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + sc.setCheckpointDir(path) + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.updateGraph(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + 
graphsToCheck.foreach { graph => + confirmCheckpointRemoved(graph.graph) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicGraphCheckpointerSuite { + + case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) + + val edges = Seq( + Edge[Double](0, 1, 0), + Edge[Double](1, 2, 0), + Edge[Double](2, 3, 0), + Edge[Double](3, 4, 0)) + + def createGraph(sc: SparkContext): Graph[Double, Double] = { + Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) + } + + def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { + graphs.foreach { g => + checkPersistence(g.graph, g.gIndex, iteration) + } + } + + /** + * Check storage level of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(graph.vertices.getStorageLevel == StorageLevel.NONE) + assert(graph.edges.getStorageLevel == StorageLevel.NONE) + } else { + assert(graph.vertices.getStorageLevel != StorageLevel.NONE) + assert(graph.edges.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" + + s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n") + } + } + + def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = { + graphs.reverse.foreach { g => + checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = { + // Note: We cannot check graph.isCheckpointed since that value is never updated. 
+ // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this graph.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). + val fs = FileSystem.get(graph.vertices.sparkContext.hadoopConfiguration) + graph.getCheckpointFiles.foreach { checkpointFile => + assert(!fs.exists(new Path(checkpointFile)), + "Graph checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkCheckpoint( + graph: Graph[_, _], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph) + // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint. 
+ if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(graph.isCheckpointed, "Graph should be checkpointed") + assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(graph) + } + } else { + // Graph should never be checkpointed + assert(!graph.isCheckpointed, "Graph should never have been checkpointed") + assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" + + s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 771878e925ea7..002cb253862b5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -166,19 +166,28 @@ class BLASSuite extends FunSuite { syr(alpha, y, dA) } } + + val xSparse = new SparseVector(4, Array(0, 2, 3), Array(1.0, 3.0, 4.0)) + val dD = new DenseMatrix(4, 4, + Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) + syr(0.1, xSparse, dD) + val expectedSparse = new DenseMatrix(4, 4, + Array(0.1, 1.2, 2.5, 3.5, 1.2, 3.2, 5.3, 4.6, 2.5, 5.3, 2.7, 4.2, 3.5, 4.6, 4.2, 2.4)) + assert(dD ~== expectedSparse absTol 1e-15) } test("gemm") { - val dA = new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) val B = new 
DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0)) val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0)) + val BTman = new DenseMatrix(2, 3, Array(1.0, 0.0, 0.0, 2.0, 0.0, 1.0)) + val BT = B.transpose - assert(dA multiply B ~== expected absTol 1e-15) - assert(sA multiply B ~== expected absTol 1e-15) + assert(dA.multiply(B) ~== expected absTol 1e-15) + assert(sA.multiply(B) ~== expected absTol 1e-15) val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0)) val C2 = C1.copy @@ -188,6 +197,10 @@ class BLASSuite extends FunSuite { val C6 = C1.copy val C7 = C1.copy val C8 = C1.copy + val C9 = C1.copy + val C10 = C1.copy + val C11 = C1.copy + val C12 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) @@ -202,26 +215,40 @@ class BLASSuite extends FunSuite { withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemm(true, false, 1.0, dA, B, 2.0, C1) + gemm(1.0, dA.transpose, B, 2.0, C1) } } - val dAT = + val dATman = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) - val sAT = + val sATman = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply B ~== expected absTol 1e-15) - assert(sAT transposeMultiply B ~== expected absTol 1e-15) - - gemm(true, false, 1.0, dAT, B, 2.0, C5) - gemm(true, false, 1.0, sAT, B, 2.0, C6) - gemm(true, false, 2.0, dAT, B, 2.0, C7) - gemm(true, false, 2.0, sAT, B, 2.0, C8) + val dATT = dATman.transpose + val sATT = sATman.transpose + val BTT = BTman.transpose.asInstanceOf[DenseMatrix] + + assert(dATT.multiply(B) ~== expected absTol 1e-15) + assert(sATT.multiply(B) ~== expected absTol 1e-15) + assert(dATT.multiply(BTT) ~== expected absTol 1e-15) + assert(sATT.multiply(BTT) ~== expected absTol 1e-15) + + gemm(1.0, dATT, BTT, 2.0, C5) + 
gemm(1.0, sATT, BTT, 2.0, C6) + gemm(2.0, dATT, BTT, 2.0, C7) + gemm(2.0, sATT, BTT, 2.0, C8) + gemm(1.0, dA, BTT, 2.0, C9) + gemm(1.0, sA, BTT, 2.0, C10) + gemm(2.0, dA, BTT, 2.0, C11) + gemm(2.0, sA, BTT, 2.0, C12) assert(C5 ~== expected2 absTol 1e-15) assert(C6 ~== expected2 absTol 1e-15) assert(C7 ~== expected3 absTol 1e-15) assert(C8 ~== expected3 absTol 1e-15) + assert(C9 ~== expected2 absTol 1e-15) + assert(C10 ~== expected2 absTol 1e-15) + assert(C11 ~== expected3 absTol 1e-15) + assert(C12 ~== expected3 absTol 1e-15) } test("gemv") { @@ -233,17 +260,13 @@ class BLASSuite extends FunSuite { val x = new DenseVector(Array(1.0, 2.0, 3.0)) val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) - assert(dA multiply x ~== expected absTol 1e-15) - assert(sA multiply x ~== expected absTol 1e-15) + assert(dA.multiply(x) ~== expected absTol 1e-15) + assert(sA.multiply(x) ~== expected absTol 1e-15) val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy val y4 = y1.copy - val y5 = y1.copy - val y6 = y1.copy - val y7 = y1.copy - val y8 = y1.copy val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) @@ -257,25 +280,18 @@ class BLASSuite extends FunSuite { assert(y4 ~== expected3 absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemv(true, 1.0, dA, x, 2.0, y1) + gemv(1.0, dA.transpose, x, 2.0, y1) } } - val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply x ~== expected absTol 1e-15) - assert(sAT transposeMultiply x ~== expected absTol 1e-15) - - gemv(true, 1.0, dAT, x, 2.0, y5) - gemv(true, 1.0, sAT, x, 2.0, y6) - gemv(true, 2.0, dAT, x, 2.0, y7) - gemv(true, 2.0, sAT, x, 2.0, y8) - assert(y5 ~== expected2 absTol 1e-15) - assert(y6 ~== expected2 absTol 
1e-15) - assert(y7 ~== expected3 absTol 1e-15) - assert(y8 ~== expected3 absTol 1e-15) + val dATT = dAT.transpose + val sATT = sAT.transpose + + assert(dATT.multiply(x) ~== expected absTol 1e-15) + assert(sATT.multiply(x) ~== expected absTol 1e-15) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 73a6d3a27d868..2031032373971 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -36,6 +36,11 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + // transposed matrix + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(matTransposed.values.eq(breeze.data), "should not copy data") } test("sparse matrix to breeze") { @@ -58,5 +63,9 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(!matTransposed.values.eq(breeze.data), "has to copy data") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a35d0fe389fdd..c098b5458fe6b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -22,6 +22,9 @@ import 
java.util.Random import org.mockito.Mockito.when import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ +import scala.collection.mutable.{Map => MutableMap} + +import org.apache.spark.mllib.util.TestingUtils._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -32,7 +35,6 @@ class MatricesSuite extends FunSuite { assert(mat.numRows === m) assert(mat.numCols === n) assert(mat.values.eq(values), "should not copy data") - assert(mat.toArray.eq(values), "toArray should not copy data") } test("dense matrix construction with wrong dimension") { @@ -135,8 +137,8 @@ class MatricesSuite extends FunSuite { val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) - val spMat2 = deMat1.toSparse() - val deMat2 = spMat1.toDense() + val spMat2 = deMat1.toSparse + val deMat2 = spMat1.toDense assert(spMat1.toBreeze === spMat2.toBreeze) assert(deMat1.toBreeze === deMat2.toBreeze) @@ -161,6 +163,66 @@ class MatricesSuite extends FunSuite { assert(deMat1.toArray === deMat2.toArray) } + test("transpose") { + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val dAT = dA.transpose.asInstanceOf[DenseMatrix] + val sAT = sA.transpose.asInstanceOf[SparseMatrix] + val dATexpected = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sATexpected = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT.toBreeze === dATexpected.toBreeze) + assert(sAT.toBreeze === sATexpected.toBreeze) + assert(dA(1, 0) === dAT(0, 1)) + assert(dA(2, 1) === dAT(1, 2)) + assert(sA(1, 0) === sAT(0, 1)) + assert(sA(2, 1) === sAT(1, 2)) + + assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") + assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), 
"should not copy array") + + assert(dAT.toSparse.toBreeze === sATexpected.toBreeze) + assert(sAT.toDense.toBreeze === dATexpected.toBreeze) + } + + test("foreachActive") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val sp = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val dn = new DenseMatrix(m, n, allValues) + + val dnMap = MutableMap[(Int, Int), Double]() + dn.foreachActive { (i, j, value) => + dnMap.put((i, j), value) + } + assert(dnMap.size === 6) + assert(dnMap(0, 0) === 1.0) + assert(dnMap(1, 0) === 2.0) + assert(dnMap(2, 0) === 0.0) + assert(dnMap(0, 1) === 0.0) + assert(dnMap(1, 1) === 4.0) + assert(dnMap(2, 1) === 5.0) + + val spMap = MutableMap[(Int, Int), Double]() + sp.foreachActive { (i, j, value) => + spMap.put((i, j), value) + } + assert(spMap.size === 4) + assert(spMap(0, 0) === 1.0) + assert(spMap(1, 0) === 2.0) + assert(spMap(1, 1) === 4.0) + assert(spMap(2, 1) === 5.0) + } + test("horzcat, vertcat, eye, speye") { val m = 3 val n = 2 @@ -168,9 +230,20 @@ class MatricesSuite extends FunSuite { val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(0, 1, 1, 2) + // transposed versions + val allValuesT = Array(1.0, 0.0, 2.0, 4.0, 0.0, 5.0) + val colPtrsT = Array(0, 1, 3, 4) + val rowIndicesT = Array(0, 0, 1, 1) val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) + val spMat1T = new SparseMatrix(n, m, colPtrsT, rowIndicesT, values) + val deMat1T = new DenseMatrix(n, m, allValuesT) + + // should equal spMat1 & deMat1 respectively + val spMat1TT = spMat1T.transpose + val deMat1TT = deMat1T.transpose + val deMat2 = Matrices.eye(3) val spMat2 = Matrices.speye(3) val deMat3 = Matrices.eye(2) @@ -180,7 +253,6 @@ class MatricesSuite extends FunSuite { val spHorz2 = Matrices.horzcat(Array(spMat1, 
deMat2)) val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) - val deHorz2 = Matrices.horzcat(Array[Matrix]()) assert(deHorz1.numRows === 3) @@ -195,8 +267,8 @@ class MatricesSuite extends FunSuite { assert(deHorz2.numCols === 0) assert(deHorz2.toArray.length === 0) - assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix) - assert(spHorz2.toBreeze === spHorz3.toBreeze) + assert(deHorz1 ~== spHorz2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spHorz2 ~== spHorz3 absTol 1e-15) assert(spHorz(0, 0) === 1.0) assert(spHorz(2, 1) === 5.0) assert(spHorz(0, 2) === 1.0) @@ -212,6 +284,17 @@ class MatricesSuite extends FunSuite { assert(deHorz1(2, 4) === 1.0) assert(deHorz1(1, 4) === 0.0) + // containing transposed matrices + val spHorzT = Matrices.horzcat(Array(spMat1TT, spMat2)) + val spHorz2T = Matrices.horzcat(Array(spMat1TT, deMat2)) + val spHorz3T = Matrices.horzcat(Array(deMat1TT, spMat2)) + val deHorz1T = Matrices.horzcat(Array(deMat1TT, deMat2)) + + assert(deHorz1T ~== deHorz1 absTol 1e-15) + assert(spHorzT ~== spHorz absTol 1e-15) + assert(spHorz2T ~== spHorz2 absTol 1e-15) + assert(spHorz3T ~== spHorz3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.horzcat(Array(spMat1, spMat3)) } @@ -238,8 +321,8 @@ class MatricesSuite extends FunSuite { assert(deVert2.numCols === 0) assert(deVert2.toArray.length === 0) - assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix) - assert(spVert2.toBreeze === spVert3.toBreeze) + assert(deVert1 ~== spVert2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spVert2 ~== spVert3 absTol 1e-15) assert(spVert(0, 0) === 1.0) assert(spVert(2, 1) === 5.0) assert(spVert(3, 0) === 1.0) @@ -251,6 +334,17 @@ class MatricesSuite extends FunSuite { assert(deVert1(3, 1) === 0.0) assert(deVert1(4, 1) === 1.0) + // containing transposed matrices + val spVertT = Matrices.vertcat(Array(spMat1TT, spMat3)) + val deVert1T = 
Matrices.vertcat(Array(deMat1TT, deMat3)) + val spVert2T = Matrices.vertcat(Array(spMat1TT, deMat3)) + val spVert3T = Matrices.vertcat(Array(deMat1TT, spMat3)) + + assert(deVert1T ~== deVert1 absTol 1e-15) + assert(spVertT ~== spVert absTol 1e-15) + assert(spVert2T ~== spVert2 absTol 1e-15) + assert(spVert3T ~== spVert3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.vertcat(Array(spMat1, spMat2)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 85ac8ccebfc59..5def899cea117 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -89,6 +89,24 @@ class VectorsSuite extends FunSuite { } } + test("vectors equals with explicit 0") { + val dv1 = Vectors.dense(Array(0, 0.9, 0, 0.8, 0)) + val sv1 = Vectors.sparse(5, Array(1, 3), Array(0.9, 0.8)) + val sv2 = Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(0, 0.9, 0, 0.8, 0)) + + val vectors = Seq(dv1, sv1, sv2) + for (v <- vectors; u <- vectors) { + assert(v === u) + assert(v.## === u.##) + } + + val another = Vectors.sparse(5, Array(0, 1, 3), Array(0, 0.9, 0.2)) + for (v <- vectors) { + assert(v != another) + assert(v.## != another.##) + } + } + test("indexing dense vectors") { val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0) assert(vec(0) === 1.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala new file mode 100644 index 0000000000000..949d1c9939570 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import java.{util => ju} + +import breeze.linalg.{DenseMatrix => BDM} +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { + + val m = 5 + val n = 4 + val rowPerPart = 2 + val colPerPart = 2 + val numPartitions = 3 + var gridBasedMat: BlockMatrix = _ + + override def beforeAll() { + super.beforeAll() + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + + gridBasedMat = new BlockMatrix(sc.parallelize(blocks, numPartitions), rowPerPart, colPerPart) + } + + test("size") { + assert(gridBasedMat.numRows() === m) + assert(gridBasedMat.numCols() === n) + } + + test("grid partitioner") { + val random = new ju.Random() + // This should generate a 4x4 grid of 1x2 
blocks. + val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12) + val expected0 = Array( + Array(0, 0, 4, 4, 8, 8, 12), + Array(1, 1, 5, 5, 9, 9, 13), + Array(2, 2, 6, 6, 10, 10, 14), + Array(3, 3, 7, 7, 11, 11, 15)) + for (i <- 0 until 4; j <- 0 until 7) { + assert(part0.getPartition((i, j)) === expected0(i)(j)) + assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((-1, 0)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((4, 0)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((0, -1)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((0, 7)) + } + + val part1 = GridPartitioner(2, 2, suggestedNumPartitions = 5) + val expected1 = Array( + Array(0, 2), + Array(1, 3)) + for (i <- 0 until 2; j <- 0 until 2) { + assert(part1.getPartition((i, j)) === expected1(i)(j)) + assert(part1.getPartition((i, j, random.nextInt())) === expected1(i)(j)) + } + + val part2 = GridPartitioner(2, 2, suggestedNumPartitions = 5) + assert(part0 !== part2) + assert(part1 === part2) + + val part3 = new GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2) + val expected3 = Array( + Array(0, 0, 2), + Array(1, 1, 3)) + for (i <- 0 until 2; j <- 0 until 3) { + assert(part3.getPartition((i, j)) === expected3(i)(j)) + assert(part3.getPartition((i, j, random.nextInt())) === expected3(i)(j)) + } + + val part4 = GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2) + assert(part3 === part4) + + intercept[IllegalArgumentException] { + new GridPartitioner(2, 2, rowsPerPart = 0, colsPerPart = 1) + } + + intercept[IllegalArgumentException] { + GridPartitioner(2, 2, rowsPerPart = 1, colsPerPart = 0) + } + + intercept[IllegalArgumentException] { + GridPartitioner(2, 2, suggestedNumPartitions = 0) + } + } + + test("toCoordinateMatrix") { + val coordMat = gridBasedMat.toCoordinateMatrix() + assert(coordMat.numRows() === m) + assert(coordMat.numCols() === 
n) + assert(coordMat.toBreeze() === gridBasedMat.toBreeze()) + } + + test("toIndexedRowMatrix") { + val rowMat = gridBasedMat.toIndexedRowMatrix() + assert(rowMat.numRows() === m) + assert(rowMat.numCols() === n) + assert(rowMat.toBreeze() === gridBasedMat.toBreeze()) + } + + test("toBreeze and toLocalMatrix") { + val expected = BDM( + (1.0, 0.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 0.0), + (3.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 1.0, 5.0)) + + val dense = Matrices.fromBreeze(expected).asInstanceOf[DenseMatrix] + assert(gridBasedMat.toLocalMatrix() === dense) + assert(gridBasedMat.toBreeze() === expected) + } + + test("add") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 0), new DenseMatrix(1, 2, Array(1.0, 0.0))), // Added block that doesn't exist in A + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val B = new BlockMatrix(rdd, rowPerPart, colPerPart) + + val expected = BDM( + (2.0, 0.0, 0.0, 0.0), + (0.0, 4.0, 2.0, 0.0), + (6.0, 2.0, 2.0, 0.0), + (0.0, 2.0, 4.0, 2.0), + (1.0, 0.0, 2.0, 10.0)) + + val AplusB = gridBasedMat.add(B) + assert(AplusB.numRows() === m) + assert(AplusB.numCols() === B.numCols()) + assert(AplusB.toBreeze() === expected) + + val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m, n + 1) // columns don't match + intercept[IllegalArgumentException] { + gridBasedMat.add(C) + } + val largerBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(4, 4, new Array[Double](16))), + ((1, 0), new DenseMatrix(1, 4, Array(1.0, 0.0, 1.0, 5.0)))) + val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4, m, n) + intercept[SparkException] { // partitioning doesn't match + gridBasedMat.add(C2) + } + // adding 
BlockMatrices composed of SparseMatrices + val sparseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), SparseMatrix.speye(4)) + val denseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), DenseMatrix.eye(4)) + val sparseBM = new BlockMatrix(sc.makeRDD(sparseBlocks, 4), 4, 4, 8, 8) + val denseBM = new BlockMatrix(sc.makeRDD(denseBlocks, 4), 4, 4, 8, 8) + + assert(sparseBM.add(sparseBM).toBreeze() === sparseBM.add(denseBM).toBreeze()) + } + + test("multiply") { + // identity matrix + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0)))) + val rdd = sc.parallelize(blocks, 2) + val B = new BlockMatrix(rdd, colPerPart, rowPerPart) + val expected = BDM( + (1.0, 0.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 0.0), + (3.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 1.0, 5.0)) + + val AtimesB = gridBasedMat.multiply(B) + assert(AtimesB.numRows() === m) + assert(AtimesB.numCols() === n) + assert(AtimesB.toBreeze() === expected) + val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m + 1, n) // dimensions don't match + intercept[IllegalArgumentException] { + gridBasedMat.multiply(C) + } + val largerBlocks = Seq(((0, 0), DenseMatrix.eye(4))) + val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4) + intercept[SparkException] { + // partitioning doesn't match + gridBasedMat.multiply(C2) + } + val rand = new ju.Random(42) + val largerAblocks = for (i <- 0 until 20) yield ((i % 5, i / 5), DenseMatrix.rand(6, 4, rand)) + val largerBblocks = for (i <- 0 until 16) yield ((i % 4, i / 4), DenseMatrix.rand(4, 4, rand)) + + // Try it with increased number of partitions + val largeA = new BlockMatrix(sc.parallelize(largerAblocks, 10), 6, 4) + val largeB = new BlockMatrix(sc.parallelize(largerBblocks, 8), 4, 4) + val largeC = largeA.multiply(largeB) + val localC = largeC.toLocalMatrix() + val result = 
largeA.toLocalMatrix().multiply(largeB.toLocalMatrix().asInstanceOf[DenseMatrix]) + assert(largeC.numRows() === largeA.numRows()) + assert(largeC.numCols() === largeB.numCols()) + assert(localC ~== result absTol 1e-8) + } + + test("validate") { + // No error + gridBasedMat.validate() + // Wrong MatrixBlock dimensions + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val wrongRowPerParts = new BlockMatrix(rdd, rowPerPart + 1, colPerPart) + val wrongColPerParts = new BlockMatrix(rdd, rowPerPart, colPerPart + 1) + intercept[SparkException] { + wrongRowPerParts.validate() + } + intercept[SparkException] { + wrongColPerParts.validate() + } + // Wrong BlockMatrix dimensions + val wrongRowSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 4, 4) + intercept[AssertionError] { + wrongRowSize.validate() + } + val wrongColSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 5, 2) + intercept[AssertionError] { + wrongColSize.validate() + } + // Duplicate indices + val duplicateBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 0), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 1), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val dupMatrix = new BlockMatrix(sc.parallelize(duplicateBlocks, numPartitions), 2, 2) + intercept[SparkException] { + dupMatrix.validate() + } + } + + test("transpose") { + val expected = BDM( + (1.0, 0.0, 3.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 0.0, 1.0, 5.0)) + + val AT = 
gridBasedMat.transpose + assert(AT.numRows() === gridBasedMat.numCols()) + assert(AT.numCols() === gridBasedMat.numRows()) + assert(AT.toBreeze() === expected) + + // make sure it works when matrices are cached as well + gridBasedMat.cache() + val AT2 = gridBasedMat.transpose + AT2.cache() + assert(AT2.toBreeze() === AT.toBreeze()) + val A = AT2.transpose + assert(A.toBreeze() === gridBasedMat.toBreeze()) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index f8709751efce6..04b36a9ef9990 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -73,6 +73,11 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(mat.toBreeze() === expected) } + test("transpose") { + val transposed = mat.transpose() + assert(mat.toBreeze().t === transposed.toBreeze()) + } + test("toIndexedRowMatrix") { val indexedRowMatrix = mat.toIndexedRowMatrix() val expected = BDM( @@ -95,4 +100,18 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { Vectors.dense(0.0, 9.0, 0.0, 0.0)) assert(rows === expected) } + + test("toBlockMatrix") { + val blockMat = mat.toBlockMatrix(2, 2) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === mat.toBreeze()) + + intercept[IllegalArgumentException] { + mat.toBlockMatrix(-1, 2) + } + intercept[IllegalArgumentException] { + mat.toBlockMatrix(2, 0) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 741cd4997b853..2ab53cc13db71 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -80,6 +80,29 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(rowMat.rows.collect().toSeq === data.map(_.vector).toSeq) } + test("toCoordinateMatrix") { + val idxRowMat = new IndexedRowMatrix(indexedRows) + val coordMat = idxRowMat.toCoordinateMatrix() + assert(coordMat.numRows() === m) + assert(coordMat.numCols() === n) + assert(coordMat.toBreeze() === idxRowMat.toBreeze()) + } + + test("toBlockMatrix") { + val idxRowMat = new IndexedRowMatrix(indexedRows) + val blockMat = idxRowMat.toBlockMatrix(2, 2) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === idxRowMat.toBreeze()) + + intercept[IllegalArgumentException] { + idxRowMat.toBlockMatrix(-1, 2) + } + intercept[IllegalArgumentException] { + idxRowMat.toBlockMatrix(2, 0) + } + } + test("multiply a local matrix") { val A = new IndexedRowMatrix(indexedRows) val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 681ce9263933b..6d6c0aa5be812 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -46,22 +46,4 @@ class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) } - - test("treeAggregate") { - val rdd = sc.makeRDD(-1000 until 1000, 10) - def seqOp = (c: Long, x: Int) => c + x - def combOp = (c1: Long, c2: Long) => c1 + c2 - for (depth <- 1 until 10) { - val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) - assert(sum === -1000L) - } - } - - 
test("treeReduce") { - val rdd = sc.makeRDD(-1000 until 1000, 10) - for (depth <- 1 until 10) { - val sum = rdd.treeReduce(_ + _, depth) - assert(sum === -1000) - } - } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index f3b7bfda788fa..8775c0ca9df84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -24,9 +24,7 @@ import scala.util.Random import org.scalatest.FunSuite import org.jblas.DoubleMatrix -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.recommendation.ALS.BlockStats import org.apache.spark.storage.StorageLevel object ALSSuite { @@ -189,22 +187,6 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false) } - test("analyze one user block and one product block") { - val localRatings = Seq( - Rating(0, 100, 1.0), - Rating(0, 101, 2.0), - Rating(0, 102, 3.0), - Rating(1, 102, 4.0), - Rating(2, 103, 5.0)) - val ratings = sc.makeRDD(localRatings, 2) - val stats = ALS.analyzeBlocks(ratings, 1, 1) - assert(stats.size === 2) - assert(stats(0) === BlockStats("user", 0, 3, 5, 4, 3)) - assert(stats(1) === BlockStats("product", 0, 4, 5, 3, 4)) - } - - // TODO: add tests for analyzing multiple user/product blocks - /** * Test if we can correctly factorize R = U * P where U and P are of known rank. 
* @@ -215,7 +197,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct * @param implicitPrefs flag to test implicit feedback - * @param bulkPredict flag to test bulk prediciton + * @param bulkPredict flag to test bulk prediction * @param negativeWeights whether the generated data can contain negative values * @param numUserBlocks number of user blocks to partition users into * @param numProductBlocks number of product blocks to partition products into diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index b9caecc904a23..9801e87576744 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { @@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) } } + + test("save/load") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = { + features.mapValues(_.toSeq).collect().toSet + } + try { + model.save(sc, path) + val newModel = MatrixFactorizationModel.load(sc, path) + assert(newModel.rank === rank) +
assert(collect(newModel.userFeatures) === collect(userFeatures)) + assert(collect(newModel.productFeatures) === collect(prodFeatures)) + } finally { + Utils.deleteRecursively(tempDir) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala new file mode 100644 index 0000000000000..7ef45248281e9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.regression + +import org.scalatest.{Matchers, FunSuite} + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { + + private def round(d: Double) = { + Math.round(d * 100).toDouble / 100 + } + + private def generateIsotonicInput(labels: Seq[Double]): Seq[(Double, Double, Double)] = { + Seq.tabulate(labels.size)(i => (labels(i), i.toDouble, 1d)) + } + + private def generateIsotonicInput( + labels: Seq[Double], + weights: Seq[Double]): Seq[(Double, Double, Double)] = { + Seq.tabulate(labels.size)(i => (labels(i), i.toDouble, weights(i))) + } + + private def runIsotonicRegression( + labels: Seq[Double], + weights: Seq[Double], + isotonic: Boolean): IsotonicRegressionModel = { + val trainRDD = sc.parallelize(generateIsotonicInput(labels, weights)).cache() + new IsotonicRegression().setIsotonic(isotonic).run(trainRDD) + } + + private def runIsotonicRegression( + labels: Seq[Double], + isotonic: Boolean): IsotonicRegressionModel = { + runIsotonicRegression(labels, Array.fill(labels.size)(1d), isotonic) + } + + test("increasing isotonic regression") { + /* + The following result could be reproduced with sklearn. + + > from sklearn.isotonic import IsotonicRegression + > x = range(9) + > y = [1, 2, 3, 1, 6, 17, 16, 17, 18] + > ir = IsotonicRegression().fit(x, y) + > print ir.predict(x) + + array([ 1. , 2. , 2. , 2. , 6. , 16.5, 16.5, 17. , 18.
]) + */ + val model = runIsotonicRegression(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18), true) + + assert(Array.tabulate(9)(x => model.predict(x)) === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) + + assert(model.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.isotonic) + } + + test("isotonic regression with size 0") { + val model = runIsotonicRegression(Seq(), true) + + assert(model.predictions === Array()) + } + + test("isotonic regression with size 1") { + val model = runIsotonicRegression(Seq(1), true) + + assert(model.predictions === Array(1.0)) + } + + test("isotonic regression strictly increasing sequence") { + val model = runIsotonicRegression(Seq(1, 2, 3, 4, 5), true) + + assert(model.predictions === Array(1, 2, 3, 4, 5)) + } + + test("isotonic regression strictly decreasing sequence") { + val model = runIsotonicRegression(Seq(5, 4, 3, 2, 1), true) + + assert(model.boundaries === Array(0, 4)) + assert(model.predictions === Array(3, 3)) + } + + test("isotonic regression with last element violating monotonicity") { + val model = runIsotonicRegression(Seq(1, 2, 3, 4, 2), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions === Array(1, 2, 3, 3)) + } + + test("isotonic regression with first element violating monotonicity") { + val model = runIsotonicRegression(Seq(4, 2, 3, 4, 5), true) + + assert(model.boundaries === Array(0, 2, 3, 4)) + assert(model.predictions === Array(3, 3, 4, 5)) + } + + test("isotonic regression with negative labels") { + val model = runIsotonicRegression(Seq(-1, -2, 0, 1, -1), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions === Array(-1.5, -1.5, 0, 0)) + } + + test("isotonic regression with unordered input") { + val trainRDD = sc.parallelize(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, 2).cache() + + val model = new IsotonicRegression().run(trainRDD) + assert(model.predictions === Array(1, 
2, 3, 4, 5)) + } + + test("weighted isotonic regression") { + val model = runIsotonicRegression(Seq(1, 2, 3, 4, 2), Seq(1, 1, 1, 1, 2), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions === Array(1, 2, 2.75, 2.75)) + } + + test("weighted isotonic regression with weights lower than 1") { + val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(1, 1, 1, 0.1, 0.1), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions.map(round) === Array(1, 2, 3.3/1.2, 3.3/1.2)) + } + + test("weighted isotonic regression with negative weights") { + val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(-1, 1, -3, 1, -5), true) + + assert(model.boundaries === Array(0.0, 1.0, 4.0)) + assert(model.predictions === Array(1.0, 10.0/6, 10.0/6)) + } + + test("weighted isotonic regression with zero weights") { + val model = runIsotonicRegression(Seq[Double](1, 2, 3, 2, 1), Seq[Double](0, 0, 0, 1, 0), true) + + assert(model.boundaries === Array(0.0, 1.0, 4.0)) + assert(model.predictions === Array(1, 2, 2)) + } + + test("isotonic regression prediction") { + val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true) + + assert(model.predict(-2) === 1) + assert(model.predict(-1) === 1) + assert(model.predict(0.5) === 1.5) + assert(model.predict(0.75) === 1.75) + assert(model.predict(1) === 2) + assert(model.predict(2) === 10d/3) + assert(model.predict(9) === 10d/3) + } + + test("isotonic regression prediction with duplicate features") { + val trainRDD = sc.parallelize( + Seq[(Double, Double, Double)]( + (2, 1, 1), (1, 1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)), 2).cache() + val model = new IsotonicRegression().run(trainRDD) + + assert(model.predict(0) === 1) + assert(model.predict(1.5) === 2) + assert(model.predict(2.5) === 4.5) + assert(model.predict(4) === 6) + } + + test("antitonic regression prediction with duplicate features") { + val trainRDD = sc.parallelize( + Seq[(Double, Double, Double)]( + (5, 1, 1), (6, 1, 
1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)), 2).cache() + val model = new IsotonicRegression().setIsotonic(false).run(trainRDD) + + assert(model.predict(0) === 6) + assert(model.predict(1.5) === 4.5) + assert(model.predict(2.5) === 2) + assert(model.predict(4) === 1) + } + + test("isotonic regression RDD prediction") { + val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true) + + val testRDD = sc.parallelize(List(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0), 2).cache() + val predictions = testRDD.map(x => (x, model.predict(x))).collect().sortBy(_._1).map(_._2) + assert(predictions === Array(1, 1, 1.5, 1.75, 2, 10.0/3, 10.0/3)) + } + + test("antitonic regression prediction") { + val model = runIsotonicRegression(Seq(7, 5, 3, 5, 1), false) + + assert(model.predict(-2) === 7) + assert(model.predict(-1) === 7) + assert(model.predict(0.5) === 6) + assert(model.predict(0.75) === 5.5) + assert(model.predict(1) === 5) + assert(model.predict(2) === 4) + assert(model.predict(9) === 1) + } + + test("model construction") { + val model = new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = true) + assert(model.predict(-0.5) === 1.0) + assert(model.predict(0.0) === 1.0) + assert(model.predict(0.5) ~== 1.5 absTol 1e-14) + assert(model.predict(1.0) === 2.0) + assert(model.predict(1.5) === 2.0) + + intercept[IllegalArgumentException] { + // different array sizes. 
+ new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0), isotonic = true) + } + + intercept[IllegalArgumentException] { + // unordered boundaries + new IsotonicRegressionModel(Array(1.0, 0.0), Array(1.0, 2.0), isotonic = true) + } + + intercept[IllegalArgumentException] { + // unordered predictions (isotonic) + new IsotonicRegressionModel(Array(0.0, 1.0), Array(2.0, 1.0), isotonic = true) + } + + intercept[IllegalArgumentException] { + // unordered predictions (antitonic) + new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = false) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 2668dcc14a842..c9f5dc069ef2e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -24,6 +24,13 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object LassoSuite { + + /** 3 features */ + val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} class LassoSuite extends FunSuite with MLlibTestSparkContext { @@ -115,6 +122,23 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("model save/load") { + val model = LassoSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. 
+ try { + model.save(sc, path) + val sameModel = LassoModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 864622a9296a6..3781931c2f819 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -24,6 +24,13 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object LinearRegressionSuite { + + /** 3 features */ + val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @@ -124,6 +131,23 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { validatePrediction( sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) } + + test("model save/load") { + val model = LinearRegressionSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. 
+ try { + model.save(sc, path) + val sameModel = LinearRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 18d3bf5ea4eca..43d61151e2471 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -25,6 +25,13 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object RidgeRegressionSuite { + + /** 3 features */ + val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { @@ -75,6 +82,23 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } + + test("model save/load") { + val model = RidgeRegressionSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. 
+ try { + model.save(sc, path) + val sameModel = RidgeRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala new file mode 100644 index 0000000000000..f6a1e19f50296 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.stat + +import org.scalatest.FunSuite + +import org.apache.commons.math3.distribution.NormalDistribution + +import org.apache.spark.mllib.util.LocalClusterSparkContext + +class KernelDensitySuite extends FunSuite with LocalClusterSparkContext { + test("kernel density single sample") { + val rdd = sc.parallelize(Array(5.0)) + val evaluationPoints = Array(5.0, 6.0) + val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val normal = new NormalDistribution(5.0, 3.0) + val acceptableErr = 1e-6 + assert((densities(0) - normal.density(5.0)).abs < acceptableErr) + assert((densities(1) - normal.density(6.0)).abs < acceptableErr) + } + + test("kernel density multiple samples") { + val rdd = sc.parallelize(Array(5.0, 10.0)) + val evaluationPoints = Array(5.0, 6.0) + val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val normal1 = new NormalDistribution(5.0, 3.0) + val normal2 = new NormalDistribution(10.0, 3.0) + val acceptableErr = 1e-6 + assert((densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2).abs < acceptableErr) + assert((densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2).abs < acceptableErr) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 9347eaf9221a8..4c162df810bb2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -29,8 +29,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} +import 
org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { @@ -188,7 +190,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 3) - assert(bins(0).length === 6) + assert(bins(0).length === 0) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -226,41 +228,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(splits(1)(2).categories.contains(0.0)) assert(splits(1)(2).categories.contains(1.0)) - // Check bins. - - assert(bins(0)(0).category === Double.MinValue) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(0.0)) - assert(bins(1)(0).category === Double.MinValue) - assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0)) - - assert(bins(0)(1).category === Double.MinValue) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).lowSplit.categories.contains(0.0)) - assert(bins(0)(1).highSplit.categories.length === 1) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(1)(1).category === Double.MinValue) - assert(bins(1)(1).lowSplit.categories.length === 1) - assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 1) - assert(bins(1)(1).highSplit.categories.contains(1.0)) - - assert(bins(0)(2).category === Double.MinValue) - assert(bins(0)(2).lowSplit.categories.length === 1) - assert(bins(0)(2).lowSplit.categories.contains(1.0)) - assert(bins(0)(2).highSplit.categories.length === 2) - assert(bins(0)(2).highSplit.categories.contains(1.0)) - 
assert(bins(0)(2).highSplit.categories.contains(0.0)) - assert(bins(1)(2).category === Double.MinValue) - assert(bins(1)(2).lowSplit.categories.length === 1) - assert(bins(1)(2).lowSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.length === 2) - assert(bins(1)(2).highSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.contains(0.0)) - } test("Multiclass classification with ordered categorical features: split and bin calculations") { @@ -857,9 +824,32 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(topNode.leftNode.get.impurity === 0.0) assert(topNode.rightNode.get.impurity === 0.0) } + + test("Node.subtreeIterator") { + val model = DecisionTreeSuite.createModel(Classification) + val nodeIds = model.topNode.subtreeIterator.map(_.id).toArray.sorted + assert(nodeIds === DecisionTreeSuite.createdModelNodeIds) + } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(Classification, Regression).foreach { algo => + val model = DecisionTreeSuite.createModel(algo) + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = DecisionTreeModel.load(sc, path) + DecisionTreeSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } } -object DecisionTreeSuite { +object DecisionTreeSuite extends FunSuite { def validateClassifier( model: DecisionTreeModel, @@ -979,4 +969,95 @@ object DecisionTreeSuite { arr } + /** Create a leaf node with the given node ID */ + private def createLeafNode(id: Int): Node = { + Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = true) + } + + /** + * Create an internal node with the given node ID and feature type. + * Note: This does NOT set the child nodes. 
+ */ + private def createInternalNode(id: Int, featureType: FeatureType): Node = { + val node = Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = false) + featureType match { + case Continuous => + node.split = Some(new Split(feature = 0, threshold = 0.5, Continuous, + categories = List.empty[Double])) + case Categorical => + node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical, + categories = List(0.0, 1.0))) + } + // TODO: The information gain stats should be consistent with the same info stored in children. + node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2, + leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6))) + node + } + + /** + * Create a tree model. This is deterministic and contains a variety of node and feature types. + */ + private[tree] def createModel(algo: Algo): DecisionTreeModel = { + val topNode = createInternalNode(id = 1, Continuous) + val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) + val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) + topNode.leftNode = Some(node2) + topNode.rightNode = Some(node3) + node3.leftNode = Some(node6) + node3.rightNode = Some(node7) + new DecisionTreeModel(topNode, algo) + } + + /** Sorted Node IDs matching the model returned by [[createModel()]] */ + private val createdModelNodeIds = Array(1, 2, 3, 6, 7) + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. 
+ */ + private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + assert(a.algo === b.algo) + checkEqual(a.topNode, b.topNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Assert that the two nodes and their descendants are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.id === b.id) + assert(a.predict === b.predict) + assert(a.impurity === b.impurity) + assert(a.isLeaf === b.isLeaf) + assert(a.split === b.split) + (a.stats, b.stats) match { + // TODO: Check other fields besides the information gain. + case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain) + case (None, None) => + case _ => throw new AssertionError( + s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})") + } + (a.leftNode, b.leftNode) match { + case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode) + case (None, None) => + case _ => throw new AssertionError("Only one instance has leftNode defined. " + + s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})") + } + (a.rightNode, b.rightNode) match { + case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode) + case (None, None) => + case _ => throw new AssertionError("Only one instance has rightNode defined. 
" + + s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 3aa97e544680b..b437aeaaf0547 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -24,8 +24,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} - +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + /** * Test suite for [[GradientBoostedTrees]]. @@ -35,32 +37,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { case (numIterations, learningRate, subsamplingRate) => - GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed => - val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) - - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) - val boostingStrategy = - new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) - - val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) - - assert(gbt.trees.size === numIterations) - try { - EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) - } catch { - case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + - s" 
subsamplingRate=$subsamplingRate") - throw e - } + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val dt = DecisionTree.train(remappedInput, treeStrategy) + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) - // Make sure trees are the same. - assert(gbt.trees.head.toString == dt.toString) + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) } } @@ -128,14 +128,78 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } + test("SPARK-5496: BoostingStrategy.defaultParams should recognize Classification") { + for (algo <- Seq("classification", "Classification", "regression", "Regression")) { + BoostingStrategy.defaultParams(algo) + } + } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + + Array(Classification, Regression).foreach { algo => + val model = new GradientBoostedTreesModel(algo, trees, treeWeights) + + // Save model, load it back, and compare. 
+ try { + model.save(sc, path) + val sameModel = GradientBoostedTreesModel.load(sc, path) + assert(model.algo == sameModel.algo) + model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) => + DecisionTreeSuite.checkEqual(treeA, treeB) + } + assert(model.treeWeights === sameModel.treeWeights) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("runWithValidation stops early and performs better on a validation dataset") { + // Set numIterations large enough so that it stops early. + val numIterations = 20 + val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) + val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) + + val algos = Array(Regression, Regression, Classification) + val losses = Array(SquaredError, AbsoluteError, LogLoss) + (algos zip losses) map { + case (algo, loss) => { + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + assert(gbtValidate.numTrees !== numIterations) + + // Test that it performs better on the validation dataset. 
+ val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) + } + } + assert(errorWithValidation <= errorWithoutValidation) + } + } + } + } -object GradientBoostedTreesSuite { +private object GradientBoostedTreesSuite { // Combinations for estimators, learning rates and subsamplingRate val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) - val randomSeeds = Array(681283, 4398) - val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) + val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120) + val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala new file mode 100644 index 0000000000000..92b498580af03 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + */ +class ImpuritySuite extends FunSuite with MLlibTestSparkContext { + test("Gini impurity does not support negative labels") { + val gini = new GiniAggregator(2) + intercept[IllegalArgumentException] { + gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } + + test("Entropy does not support negative labels") { + val entropy = new EntropyAggregator(2) + intercept[IllegalArgumentException] { + entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index f7f0f20c6c125..ee3bc98486862 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -27,8 +27,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} -import org.apache.spark.mllib.tree.model.Node +import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + 
/** * Test suite for [[RandomForest]]. @@ -196,6 +198,42 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { featureSubsetStrategy = "sqrt", seed = 12345) EnsembleTestHelper.validateClassifier(model, arr, 1.0) } -} + test("subsampling rate in RandomForest"){ + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int], + useNodeIdCache = true) + val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + strategy.subsamplingRate = 0.5 + val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + assert(rf1.toDebugString != rf2.toDebugString) + } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(Classification, Regression).foreach { algo => + val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(algo)).toArray + val model = new RandomForestModel(algo, trees) + + // Save model, load it back, and compare. 
+ try { + model.save(sc, path) + val sameModel = RandomForestModel.load(sc, path) + assert(model.algo == sameModel.algo) + model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) => + DecisionTreeSuite.checkEqual(treeA, treeB) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + +} diff --git a/network/common/pom.xml b/network/common/pom.xml index 245a96b8c4038..8f7c924d6b3a3 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -48,10 +48,15 @@ slf4j-api provided + com.google.guava guava - provided + compile @@ -87,11 +92,6 @@ maven-jar-plugin 2.2 - - - test-jar - - test-jar-on-test-compile test-compile diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 625c3257d764e..ef209991804b4 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -100,8 +100,7 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); - channelFuture.syncUninterruptibly(); + bindRightPort(portToBind); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); logger.debug("Shuffle server started on port :" + port); @@ -123,4 +122,37 @@ public void close() { bootstrap = null; } + /** + * Attempt to bind to the specified port up to a fixed number of retries. + * If all attempts fail after the max number of retries, exit. 
+ */ + private void bindRightPort(int portToBind) { + int maxPortRetries = conf.portMaxRetries(); + + for (int i = 0; i <= maxPortRetries; i++) { + int tryPort = -1; + if (0 == portToBind) { + // Do not increment port if tryPort is 0, which is treated as a special port + tryPort = 0; + } else { + // If the new port wraps around, do not try a privilege port + tryPort = ((portToBind + i - 1024) % (65536 - 1024)) + 1024; + } + try { + channelFuture = bootstrap.bind(new InetSocketAddress(tryPort)); + channelFuture.syncUninterruptibly(); + return; + } catch (Exception e) { + logger.warn("Netty service could not bind on port " + tryPort + + ". Attempting the next port."); + if (i >= maxPortRetries) { + logger.error(e.getMessage() + ": Netty server failed after " + + maxPortRetries + " retries."); + + // If it can't find a right port, it should exit directly. + System.exit(-1); + } + } + } + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 6c9178688693f..2eaf3b71d9a49 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -98,4 +98,11 @@ public int memoryMapBytes() { public boolean lazyFileDescriptor() { return conf.getBoolean("spark.shuffle.io.lazyFD", true); } + + /** + * Maximum number of retries when binding to a port before giving up. 
+ */ + public int portMaxRetries() { + return conf.getInt("spark.port.maxRetries", 16); + } } diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 5bfa1ac9c373e..c2d0300ecd904 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -52,7 +52,6 @@ com.google.guava guava - provided diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index a34aabe9e78a6..63b21222e7b77 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -76,6 +76,9 @@ public class YarnShuffleService extends AuxiliaryService { // The actual server that serves shuffle files private TransportServer shuffleServer = null; + // Handles registering executors and opening shuffle blocks + private ExternalShuffleBlockHandler blockHandler; + public YarnShuffleService() { super("spark_shuffle"); logger.info("Initializing YARN shuffle service for Spark"); @@ -99,7 +102,8 @@ protected void serviceInit(Configuration conf) { // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - RpcHandler rpcHandler = new ExternalShuffleBlockHandler(transportConf); + blockHandler = new ExternalShuffleBlockHandler(transportConf); + RpcHandler rpcHandler = blockHandler; if (authEnabled) { secretManager = new ShuffleSecretManager(); rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); @@ -136,6 +140,7 @@ public void stopApplication(ApplicationTerminationContext context) { if (isAuthenticationEnabled()) { secretManager.unregisterApp(appId); } + blockHandler.applicationRemoved(appId, false /* clean up local dirs */); } catch (Exception e) { logger.error("Exception when 
stopping application {}", appId, e); } diff --git a/pom.xml b/pom.xml index f4466e56c2a53..bb355bf735bee 100644 --- a/pom.xml +++ b/pom.xml @@ -117,7 +117,7 @@ 2.0.1 0.21.0 shaded-protobuf - 1.7.5 + 1.7.10 1.2.17 1.0.4 2.4.1 @@ -135,8 +135,11 @@ 1.6.0rc3 1.2.3 8.1.14.v20131031 + 3.0.0.v201112011016 0.5.0 - 3.0.0 + 2.4.0 + 2.0.8 + 3.1.0 1.7.6 0.7.1 @@ -149,7 +152,9 @@ 2.10 ${scala.version} org.scala-lang + 3.6.3 1.8.8 + 2.4.4 1.1.1.6 + + org.eclipse.jetty + jetty-http + ${jetty.version} + provided + + + org.eclipse.jetty + jetty-continuation + ${jetty.version} + provided + + + org.eclipse.jetty + jetty-servlet + ${jetty.version} + provided + org.eclipse.jetty jetty-util ${jetty.version} + provided org.eclipse.jetty jetty-security ${jetty.version} + provided org.eclipse.jetty jetty-plus ${jetty.version} + provided org.eclipse.jetty jetty-server ${jetty.version} + provided com.google.guava @@ -363,6 +394,8 @@ 14.0.1 provided + + org.apache.commons commons-lang3 @@ -371,7 +404,7 @@ commons-codec commons-codec - 1.5 + 1.10 org.apache.commons @@ -521,30 +554,40 @@ ${derby.version} - com.codahale.metrics + io.dropwizard.metrics metrics-core ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-jvm ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-json ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-ganglia ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-graphite ${codahale.metrics.version} + + com.fasterxml.jackson.core + jackson-databind + ${fasterxml.jackson.version} + + + com.fasterxml.jackson.module + jackson-module-scala_2.10 + ${fasterxml.jackson.version} + org.scala-lang scala-compiler @@ -576,19 +619,6 @@ 2.2.1 test - - org.easymock - easymockclassextension - 3.1 - test - - - - asm - asm - 3.3.1 - test - org.mockito mockito-all @@ -896,6 +926,16 @@ ${codehaus.jackson.version} ${hadoop.deps.scope} + + org.codehaus.jackson + jackson-xc 
+ ${codehaus.jackson.version} + + + org.codehaus.jackson + jackson-jaxrs + ${codehaus.jackson.version} + ${hive.group} hive-beeline @@ -922,6 +962,10 @@ com.esotericsoftware.kryo kryo + + org.apache.avro + avro-mapred + @@ -1039,6 +1083,12 @@ scala-maven-plugin 3.2.0 + + eclipse-add-source + + add-source + + scala-compile-first process-resources @@ -1121,13 +1171,19 @@ ${project.build.directory}/surefire-reports -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + + ${test_classpath} + true ${session.executionRootDirectory} 1 false false - ${test_classpath} true false @@ -1264,7 +1320,10 @@ - + org.apache.maven.plugins maven-shade-plugin @@ -1273,9 +1332,44 @@ false + org.spark-project.spark:unused + + org.eclipse.jetty:jetty-io + org.eclipse.jetty:jetty-http + org.eclipse.jetty:jetty-continuation + org.eclipse.jetty:jetty-servlet + org.eclipse.jetty:jetty-plus + org.eclipse.jetty:jetty-security + org.eclipse.jetty:jetty-util + org.eclipse.jetty:jetty-server + com.google.guava:guava + + + org.eclipse.jetty + org.spark-project.jetty + + org.eclipse.jetty.** + + + + com.google.common + org.spark-project.guava + + + com/google/common/base/Absent* + com/google/common/base/Function + com/google/common/base/Optional* + com/google/common/base/Present* + com/google/common/base/Supplier + + + @@ -1468,6 +1562,7 @@ 2.5.0 0.98.7-hadoop2 hadoop2 + 1.9.13 @@ -1476,10 +1571,11 @@ 2.3.0 2.5.0 - 0.9.0 + 0.9.3 0.98.7-hadoop2 3.1.1 hadoop2 + 1.9.13 @@ -1488,10 +1584,11 @@ 2.4.0 2.5.0 - 0.9.0 + 0.9.3 0.98.7-hadoop2 3.1.1 hadoop2 + 1.9.13 @@ -1507,7 +1604,7 @@ mapr3 1.0.3-mapr-3.0.3 - 2.3.0-mapr-4.0.0-FCS + 2.4.1-mapr-1408 0.94.17-mapr-1405 3.4.5-mapr-1406 @@ -1516,8 +1613,8 @@ mapr4 - 2.3.0-mapr-4.0.0-FCS - 2.3.0-mapr-4.0.0-FCS + 2.4.1-mapr-1408 + 2.4.1-mapr-1408 0.94.17-mapr-1405-4.0.0-FCS 3.4.5-mapr-1406 @@ -1577,6 +1674,7 @@ external/kafka + external/kafka-assembly diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 
d3ea594245722..ee6229aa6bbe1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,7 @@ object MimaExcludes { case v if v.startsWith("1.3") => Seq( MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), // These are needed if checking against the sbt build, since they are part of // the maven-generated artifacts in the 1.2 build. MimaBuild.excludeSparkPackage("unused"), @@ -52,6 +53,29 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrices.randn"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-5321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.transpose"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." 
+ + "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.isTransposed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.foreachActive") + ) ++ Seq( + // SPARK-5540 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), + // SPARK-5536 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") ) ++ Seq( // SPARK-3325 ProblemFilters.exclude[MissingMethodProblem]( @@ -78,6 +102,57 @@ object MimaExcludes { "org.apache.spark.TaskContext.taskAttemptId"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.TaskContext.attemptNumber") + ) ++ Seq( + // SPARK-5166 Spark SQL API stabilization + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") + ) ++ Seq( + // SPARK-5270 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.isEmpty") + ) ++ Seq( + // SPARK-5430 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeReduce"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeAggregate") + ) ++ Seq( + // SPARK-5297 Java FileStream do not work with custom key/values + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") + ) ++ Seq( + // SPARK-5315 Spark Streaming Java API returns Scala DStream + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") + ) ++ Seq( + // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.getCheckpointFiles"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.isCheckpointed") + ) ++ Seq( + // SPARK-4789 Standardize ML Prediction APIs + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), + 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") + ) ++ Seq( + // SPARK-4682 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") ) case v if v.startsWith("1.2") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b2c546da21c70..e4b1b96527fbd 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -44,8 +44,9 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn) = - Seq("assembly", "examples", "network-yarn").map(ProjectRef(buildLocation, _)) + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly") + .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") // Root project. @@ -114,17 +115,6 @@ object SparkBuild extends PomBuild { override val userPropertiesMap = System.getProperties.toMap - // Handle case where hadoop.version is set via profile. - // Needed only because we read back this property in sbt - // when we create the assembly jar. 
- val pom = loadEffectivePom(new File("pom.xml"), - profiles = profiles, - userProps = userPropertiesMap) - if (System.getProperty("hadoop.version") == null) { - System.setProperty("hadoop.version", - pom.getProperties.get("hadoop.version").asInstanceOf[String]) - } - lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") @@ -187,6 +177,29 @@ object SparkBuild extends PomBuild { enable(Flume.settings)(streamingFlumeSink) + + /** + * Adds the ability to run the spark shell directly from SBT without building an assembly + * jar. + * + * Usage: `build/sbt sparkShell` + */ + val sparkShell = taskKey[Unit]("start a spark-shell.") + + enable(Seq( + connectInput in run := true, + fork := true, + outputStrategy in run := Some (StdoutOutput), + + javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=1g"), + + sparkShell := { + (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value + } + ))(assembly) + + enable(Seq(sparkShell := sparkShell in "assembly"))(spark) + // TODO: move this to its upstream project. 
override def projectDefinitions(baseDirectory: File): Seq[Project] = { super.projectDefinitions(baseDirectory).map { x => @@ -255,6 +268,7 @@ object SQL { |import org.apache.spark.sql.catalyst.plans.logical._ |import org.apache.spark.sql.catalyst.rules._ |import org.apache.spark.sql.catalyst.util._ + |import org.apache.spark.sql.Dsl._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.test.TestSQLContext._ |import org.apache.spark.sql.types._ @@ -285,6 +299,7 @@ object Hive { |import org.apache.spark.sql.catalyst.plans.logical._ |import org.apache.spark.sql.catalyst.rules._ |import org.apache.spark.sql.catalyst.util._ + |import org.apache.spark.sql.Dsl._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ @@ -303,14 +318,20 @@ object Assembly { import sbtassembly.Plugin._ import AssemblyKeys._ + val hadoopVersion = taskKey[String]("The version of hadoop that spark is compiled against.") + lazy val settings = assemblySettings ++ Seq( test in assembly := {}, - jarName in assembly <<= (version, moduleName) map { (v, mName) => - if (mName.contains("network-yarn")) { - // This must match the same name used in maven (see network/yarn/pom.xml) - "spark-" + v + "-yarn-shuffle.jar" + hadoopVersion := { + sys.props.get("hadoop.version") + .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) + }, + jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => + if (mName.contains("streaming-kafka-assembly")) { + // This must match the same name used in maven (see external/kafka-assembly/pom.xml) + s"${mName}-${v}.jar" } else { - mName + "-" + v + "-hadoop" + System.getProperty("hadoop.version") + ".jar" + s"${mName}-${v}-hadoop${hv}.jar" } }, mergeStrategy in assembly := { @@ -323,7 +344,6 @@ object Assembly { case _ => MergeStrategy.first } ) - } object Unidoc { @@ -341,9 +361,16 @@ object Unidoc { publish := {}, 
unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, catalyst, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn), + + // Skip actual catalyst, but include the subproject. + // Catalyst is not public API and contains quasiquotes which break scaladoc. + unidocAllSources in (ScalaUnidoc, unidoc) := { + (unidocAllSources in (ScalaUnidoc, unidoc)).value + .map(_.filterNot(_.getCanonicalPath.contains("sql/catalyst"))) + }, // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { @@ -356,6 +383,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("executor"))) .map(_.filterNot(_.getCanonicalPath.contains("python"))) .map(_.filterNot(_.getCanonicalPath.contains("collection"))) + .map(_.filterNot(_.getCanonicalPath.contains("sql/catalyst"))) }, // Javadoc options: create a window title, and group key packages on index page @@ -378,7 +406,10 @@ object Unidoc { ), "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" - ) + ), + + // Group similar methods together based on the @group annotation. + scalacOptions in (ScalaUnidoc, unidoc) ++= Seq("-groups") ) } @@ -388,6 +419,10 @@ object TestSettings { lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, + // Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes + // launched by the tests have access to the correct test-time classpath. 
+ envVars in Test += ("SPARK_DIST_CLASSPATH" -> + (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":")), javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", @@ -400,10 +435,6 @@ object TestSettings { javaOptions in Test += "-ea", javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, - // This places test scope jars on the classpath of executors during tests. - javaOptions in Test += - "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files. - map(_.getAbsolutePath).mkString(":").stripSuffix(":"), javaOptions += "-Xmx3g", // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), diff --git a/project/build.properties b/project/build.properties index 32a3aeefaf9fb..064ec843da9ea 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.6 +sbt.version=0.13.7 diff --git a/python/docs/conf.py b/python/docs/conf.py index e58d97ae6a746..163987dd8e5fa 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -48,16 +48,16 @@ # General information about the project. project = u'PySpark' -copyright = u'2014, Author' +copyright = u'' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '1.2-SNAPSHOT' +version = 'master' # The full version, including alpha/beta/rc tags. -release = '1.2-SNAPSHOT' +release = os.environ.get('RELEASE_VERSION', version) # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. 
@@ -97,6 +97,10 @@ # If true, keep warnings as "system message" paragraphs in the built documents. #keep_warnings = False +# -- Options for autodoc -------------------------------------------------- + +# Look at the first line of the docstring for function and method signatures. +autodoc_docstring_signature = True # -- Options for HTML output ---------------------------------------------- diff --git a/python/docs/index.rst b/python/docs/index.rst index 703bef644de28..d150de9d5c502 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -14,6 +14,7 @@ Contents: pyspark pyspark.sql pyspark.streaming + pyspark.ml pyspark.mllib diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst new file mode 100644 index 0000000000000..4da6d4a74a299 --- /dev/null +++ b/python/docs/pyspark.ml.rst @@ -0,0 +1,26 @@ +pyspark.ml package +===================== + +Module Context +-------------- + +.. automodule:: pyspark.ml + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.feature module +------------------------- + +.. automodule:: pyspark.ml.feature + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.classification module +-------------------------------- + +.. automodule:: pyspark.ml.classification + :members: + :undoc-members: + :inherited-members: diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 4548b8739ed91..b706c5e376ef4 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -1,16 +1,13 @@ pyspark.mllib package ===================== -Submodules ----------- - pyspark.mllib.classification module ----------------------------------- .. automodule:: pyspark.mllib.classification :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.clustering module ------------------------------- @@ -18,7 +15,6 @@ pyspark.mllib.clustering module .. 
automodule:: pyspark.mllib.clustering :members: :undoc-members: - :show-inheritance: pyspark.mllib.feature module ------------------------------- @@ -42,7 +38,6 @@ pyspark.mllib.random module .. automodule:: pyspark.mllib.random :members: :undoc-members: - :show-inheritance: pyspark.mllib.recommendation module ----------------------------------- @@ -50,7 +45,6 @@ pyspark.mllib.recommendation module .. automodule:: pyspark.mllib.recommendation :members: :undoc-members: - :show-inheritance: pyspark.mllib.regression module ------------------------------- @@ -58,7 +52,7 @@ pyspark.mllib.regression module .. automodule:: pyspark.mllib.regression :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.stat module ------------------------- @@ -66,7 +60,6 @@ pyspark.mllib.stat module .. automodule:: pyspark.mllib.stat :members: :undoc-members: - :show-inheritance: pyspark.mllib.tree module ------------------------- @@ -74,7 +67,7 @@ pyspark.mllib.tree module .. automodule:: pyspark.mllib.tree :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.util module ------------------------- @@ -82,4 +75,3 @@ pyspark.mllib.util module .. automodule:: pyspark.mllib.util :members: :undoc-members: - :show-inheritance: diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index e81be3b6cb796..0df12c49ad033 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -9,6 +9,7 @@ Subpackages pyspark.sql pyspark.streaming + pyspark.ml pyspark.mllib Contents diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 65b3650ae10ab..6259379ed05b7 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -1,10 +1,23 @@ pyspark.sql module ================== -Module contents ---------------- +Module Context +-------------- .. automodule:: pyspark.sql :members: :undoc-members: - :show-inheritance: + + +pyspark.sql.types module +------------------------ +.. 
automodule:: pyspark.sql.types + :members: + :undoc-members: + + +pyspark.sql.functions module +---------------------------- +.. automodule:: pyspark.sql.functions + :members: + :undoc-members: diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 9556e4718e585..5f70ac6ed8fe6 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -22,17 +22,17 @@ - :class:`SparkContext`: Main entry point for Spark functionality. - - L{RDD} + - :class:`RDD`: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} + - :class:`Broadcast`: A broadcast variable that gets reused across tasks. - - L{Accumulator} + - :class:`Accumulator`: An "add-only" shared variable that tasks can only add values to. - - L{SparkConf} + - :class:`SparkConf`: For configuring Spark. - - L{SparkFiles} + - :class:`SparkFiles`: Access files shipped with jobs. - - L{StorageLevel} + - :class:`StorageLevel`: Finer-grained cache persistence levels. """ @@ -45,6 +45,8 @@ from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer +from pyspark.status import * +from pyspark.profiler import Profiler, BasicProfiler # for back compatibility from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row @@ -52,4 +54,5 @@ __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", + "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", ] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index b8cdbbe3cf2b6..ccbca67656c8d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -215,21 +215,6 @@ def addInPlace(self, value1, value2): COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) -class PStatsParam(AccumulatorParam): - """PStatsParam is used to 
merge pstats.Stats""" - - @staticmethod - def zero(value): - return None - - @staticmethod - def addInPlace(value1, value2): - if value1 is None: - return value2 - value1.add(value2) - return value1 - - class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 64f6a3ca6bf4c..6011caf9f1c5a 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,7 +20,6 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile -import atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -33,6 +32,8 @@ from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call +from pyspark.status import StatusTracker +from pyspark.profiler import ProfilerCollector, BasicProfiler from py4j.java_collections import ListConverter @@ -64,9 +65,11 @@ class SparkContext(object): _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH + PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar') + def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, - gateway=None, jsc=None): + gateway=None, jsc=None, profiler_cls=BasicProfiler): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -88,6 +91,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, :param conf: A L{SparkConf} object setting Spark properties. :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. + :param jsc: The JavaSparkContext instance (optional). + :param profiler_cls: A class of custom Profiler used to do profiling + (default is pyspark.profiler.BasicProfiler). 
>>> from pyspark.context import SparkContext @@ -102,14 +108,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc) + conf, jsc, profiler_cls) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc): + conf, jsc, profiler_cls): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -182,17 +188,22 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename.lower().endswith("zip") or filename.lower().endswith("egg"): + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) self._temp_dir = \ - self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() + self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "pyspark") \ + .getAbsolutePath() # profiling stats collected for each PythonRDD - self._profile_stats = [] + if self._conf.get("spark.python.profile", "false") == "true": + dump_path = self._conf.get("spark.python.profile.dump", None) + self.profiler_collector = ProfilerCollector(profiler_cls, dump_path) + else: + self.profiler_collector = None def _initialize_context(self, jconf): """ @@ -229,6 +240,14 @@ def _ensure_initialized(cls, instance=None, gateway=None): else: 
SparkContext._active_spark_context = instance + def __getnewargs__(self): + # This method is called when attempting to pickle SparkContext, which is always an error: + raise Exception( + "It appears that you are attempting to reference SparkContext from a broadcast " + "variable, action, or transforamtion. SparkContext can only be used on the driver, " + "not in code that it run on workers. For more information, see SPARK-5063." + ) + def __enter__(self): """ Enable 'with SparkContext(...) as sc: app(sc)' syntax. @@ -689,7 +708,7 @@ def addPyFile(self, path): self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix - if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) # for tests in local mode sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) @@ -792,6 +811,12 @@ def cancelAllJobs(self): """ self._jsc.sc().cancelAllJobs() + def statusTracker(self): + """ + Return :class:`StatusTracker` object + """ + return StatusTracker(self._jsc.statusTracker()) + def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ Executes the given partitionFunc on the specified set of partitions, @@ -818,39 +843,14 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) - def _add_profile(self, id, profileAcc): - if not self._profile_stats: - dump_path = self._conf.get("spark.python.profile.dump") - if dump_path: - atexit.register(self.dump_profiles, dump_path) - else: - atexit.register(self.show_profiles) - - self._profile_stats.append([id, profileAcc, False]) - def show_profiles(self): """ Print the profile stats to stdout """ - for i, (id, acc, showed) in enumerate(self._profile_stats): - stats = 
acc.value - if not showed and stats: - print "=" * 60 - print "Profile of RDD" % id - print "=" * 60 - stats.sort_stats("time", "cumulative").print_stats() - # mark it as showed - self._profile_stats[i][2] = True + self.profiler_collector.show_profiles() def dump_profiles(self, path): """ Dump the profile stats into directory `path` """ - if not os.path.exists(path): - os.makedirs(path) - for id, acc, _ in self._profile_stats: - stats = acc.value - if stats: - p = os.path.join(path, "rdd_%d.pstats" % id) - stats.dump_stats(p) - self._profile_stats = [] + self.profiler_collector.dump_profiles(path) def _test(): diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index a975dc19cb78e..936857e75c7e9 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -17,19 +17,20 @@ import atexit import os -import sys +import select import signal import shlex +import socket import platform from subprocess import Popen, PIPE -from threading import Thread from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from pyspark.serializers import read_int + def launch_gateway(): SPARK_HOME = os.environ["SPARK_HOME"] - gateway_port = -1 if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: @@ -41,36 +42,42 @@ def launch_gateway(): submit_args = submit_args if submit_args is not None else "" submit_args = shlex.split(submit_args) command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"] + + # Start a socket that will be used by PythonGatewayServer to communicate its port to us + callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + callback_socket.bind(('127.0.0.1', 0)) + callback_socket.listen(1) + callback_host, callback_port = callback_socket.getsockname() + env = dict(os.environ) + env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host + env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) + + # Launch the Java gateway. 
+ # We open a pipe to stdin so that the Java gateway can die when the pipe is broken if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - env = dict(os.environ) env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) else: # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE) + proc = Popen(command, stdin=PIPE, env=env) - try: - # Determine which ephemeral port the server started on: - gateway_port = proc.stdout.readline() - gateway_port = int(gateway_port) - except ValueError: - # Grab the remaining lines of stdout - (stdout, _) = proc.communicate() - exit_code = proc.poll() - error_msg = "Launching GatewayServer failed" - error_msg += " with exit code %d!\n" % exit_code if exit_code else "!\n" - error_msg += "Warning: Expected GatewayServer to output a port, but found " - if gateway_port == "" and stdout == "": - error_msg += "no output.\n" - else: - error_msg += "the following:\n\n" - error_msg += "--------------------------------------------------------------\n" - error_msg += gateway_port + stdout - error_msg += "--------------------------------------------------------------\n" - raise Exception(error_msg) + gateway_port = None + # We use select() here in order to avoid blocking indefinitely if the subprocess dies + # before connecting + while gateway_port is None and proc.poll() is None: + timeout = 1 # (seconds) + readable, _, _ = select.select([callback_socket], [], [], timeout) + if callback_socket in readable: + gateway_connection = callback_socket.accept()[0] + # Determine which ephemeral port the server started on: + gateway_port = read_int(gateway_connection.makefile()) + gateway_connection.close() + callback_socket.close() + if gateway_port is None: + raise Exception("Java 
gateway process exited before sending the driver its port number") # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -88,21 +95,6 @@ def killChild(): Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)]) atexit.register(killChild) - # Create a thread to echo output from the GatewayServer, which is required - # for Java log output to show up: - class EchoOutputThread(Thread): - - def __init__(self, stream): - Thread.__init__(self) - self.daemon = True - self.stream = stream - - def run(self): - while True: - line = self.stream.readline() - sys.stderr.write(line) - EchoOutputThread(proc.stdout).start() - # Connect to the gateway gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) @@ -111,10 +103,9 @@ def run(self): java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") - java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") + # TODO(davies): move into sql + java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/join.py b/python/pyspark/join.py index b4a844713745a..efc1ef9396412 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -35,8 +35,8 @@ def _do_python_join(rdd, other, numPartitions, dispatch): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) + vs = rdd.mapValues(lambda v: (1, v)) + ws = other.mapValues(lambda v: (2, v)) return 
vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__())) @@ -98,8 +98,8 @@ def dispatch(seq): def python_cogroup(rdds, numPartitions): def make_mapper(i): - return lambda (k, v): (k, (i, v)) - vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] + return lambda v: (i, v) + vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)] union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) rdd_len = len(vrdds) diff --git a/sbin/spark-executor b/python/pyspark/ml/__init__.py old mode 100755 new mode 100644 similarity index 72% rename from sbin/spark-executor rename to python/pyspark/ml/__init__.py index 674ce906d9421..47fed80f42e13 --- a/sbin/spark-executor +++ b/python/pyspark/ml/__init__.py @@ -1,5 +1,3 @@ -#!/bin/sh - # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -17,10 +15,7 @@ # limitations under the License. # -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -export PYTHONPATH="$FWDIR/python:$PYTHONPATH" -export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" +from pyspark.ml.param import * +from pyspark.ml.pipeline import * -echo "Running spark-executor with framework dir = $FWDIR" -exec "$FWDIR"/bin/spark-class org.apache.spark.executor.MesosExecutorBackend +__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py new file mode 100644 index 0000000000000..4ff7463498cce --- /dev/null +++ b/python/pyspark/ml/classification.py @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ + HasRegParam +from pyspark.mllib.common import inherit_doc + + +__all__ = ['LogisticRegression', 'LogisticRegressionModel'] + + +@inherit_doc +class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, + HasRegParam): + """ + Logistic regression. + + >>> from pyspark.sql import Row + >>> from pyspark.mllib.linalg import Vectors + >>> df = sc.parallelize([ + ... Row(label=1.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + >>> model = lr.fit(df) + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() + >>> print model.transform(test0).head().prediction + 0.0 + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() + >>> print model.transform(test1).head().prediction + 1.0 + >>> lr.setParams("vector") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. 
+ """ + _java_class = "org.apache.spark.ml.classification.LogisticRegression" + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1) + """ + super(LogisticRegression, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1) + Sets params for logistic regression. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + + def _create_model(self, java_model): + return LogisticRegressionModel(java_model) + + +class LogisticRegressionModel(JavaModel): + """ + Model fitted by LogisticRegression. + """ + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.feature tests") + sqlCtx = SQLContext(sc) + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py new file mode 100644 index 0000000000000..433b4fb5d22bf --- /dev/null +++ b/python/pyspark/ml/feature.py @@ -0,0 +1,127 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaTransformer +from pyspark.mllib.common import inherit_doc + +__all__ = ['Tokenizer', 'HashingTF'] + + +@inherit_doc +class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + A tokenizer that converts the input string to lowercase and then + splits it by white spaces. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(text="a b c")]).toDF() + >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") + >>> print tokenizer.transform(df).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> # Change a parameter. + >>> print tokenizer.setParams(outputCol="tokens").transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Temporarily modify a parameter. + >>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> print tokenizer.transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Must use keyword arguments to specify params. + >>> tokenizer.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. 
+ """ + + _java_class = "org.apache.spark.ml.feature.Tokenizer" + + @keyword_only + def __init__(self, inputCol="input", outputCol="output"): + """ + __init__(self, inputCol="input", outputCol="output") + """ + super(Tokenizer, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol="input", outputCol="output"): + """ + setParams(self, inputCol="input", outputCol="output") + Sets params for this Tokenizer. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + + +@inherit_doc +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): + """ + Maps a sequence of terms to their term frequencies using the + hashing trick. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF() + >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + >>> print hashingTF.transform(df).head().features + (10,[7,8,9],[1.0,1.0,1.0]) + >>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs + (10,[7,8,9],[1.0,1.0,1.0]) + >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} + >>> print hashingTF.transform(df, params).head().vector + (5,[2,3,4],[1.0,1.0,1.0]) + """ + + _java_class = "org.apache.spark.ml.feature.HashingTF" + + @keyword_only + def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + """ + __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + """ + super(HashingTF, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + """ + setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + Sets params for this HashingTF. 
+ """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.feature tests") + sqlCtx = SQLContext(sc) + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py new file mode 100644 index 0000000000000..e3a53dd780c4c --- /dev/null +++ b/python/pyspark/ml/param/__init__.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta + +from pyspark.ml.util import Identifiable + + +__all__ = ['Param', 'Params'] + + +class Param(object): + """ + A param with self-contained documentation and optionally default value. 
+ """ + + def __init__(self, parent, name, doc, defaultValue=None): + if not isinstance(parent, Identifiable): + raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__) + self.parent = parent + self.name = str(name) + self.doc = str(doc) + self.defaultValue = defaultValue + + def __str__(self): + return str(self.parent) + "-" + self.name + + def __repr__(self): + return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \ + (self.parent, self.name, self.doc, self.defaultValue) + + +class Params(Identifiable): + """ + Components that take parameters. This also provides an internal + param map to store parameter values attached to the instance. + """ + + __metaclass__ = ABCMeta + + def __init__(self): + super(Params, self).__init__() + #: embedded param map + self.paramMap = {} + + @property + def params(self): + """ + Returns all params. The default implementation uses + :py:func:`dir` to get all attributes of type + :py:class:`Param`. + """ + return filter(lambda attr: isinstance(attr, Param), + [getattr(self, x) for x in dir(self) if x != "params"]) + + def _merge_params(self, params): + paramMap = self.paramMap.copy() + paramMap.update(params) + return paramMap + + @staticmethod + def _dummy(): + """ + Returns a dummy Params instance used as a placeholder to generate docs. + """ + dummy = Params() + dummy.uid = "undefined" + return dummy + + def _set_params(self, **kwargs): + """ + Sets params. + """ + for param, value in kwargs.iteritems(): + self.paramMap[getattr(self, param)] = value + return self diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py new file mode 100644 index 0000000000000..5eb81106f116c --- /dev/null +++ b/python/pyspark/ml/param/_gen_shared_params.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +header = """# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#""" + + +def _gen_param_code(name, doc, defaultValue): + """ + Generates Python code for a shared param class. + + :param name: param name + :param doc: param doc + :param defaultValue: string representation of the param + :return: code string + """ + # TODO: How to correctly inherit instance attributes? + template = '''class Has$Name(Params): + """ + Params with $name. 
+ """ + + # a placeholder to make it appear in the generated doc + $name = Param(Params._dummy(), "$name", "$doc", $defaultValue) + + def __init__(self): + super(Has$Name, self).__init__() + #: param for $doc + self.$name = Param(self, "$name", "$doc", $defaultValue) + + def set$Name(self, value): + """ + Sets the value of :py:attr:`$name`. + """ + self.paramMap[self.$name] = value + return self + + def get$Name(self): + """ + Gets the value of $name or its default value. + """ + if self.$name in self.paramMap: + return self.paramMap[self.$name] + else: + return self.$name.defaultValue''' + + upperCamelName = name[0].upper() + name[1:] + return template \ + .replace("$name", name) \ + .replace("$Name", upperCamelName) \ + .replace("$doc", doc) \ + .replace("$defaultValue", defaultValue) + +if __name__ == "__main__": + print header + print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n" + print "from pyspark.ml.param import Param, Params\n\n" + shared = [ + ("maxIter", "max number of iterations", "100"), + ("regParam", "regularization constant", "0.1"), + ("featuresCol", "features column name", "'features'"), + ("labelCol", "label column name", "'label'"), + ("predictionCol", "prediction column name", "'prediction'"), + ("inputCol", "input column name", "'input'"), + ("outputCol", "output column name", "'output'"), + ("numFeatures", "number of features", "1 << 18")] + code = [] + for name, doc, defaultValue in shared: + code.append(_gen_param_code(name, doc, defaultValue)) + print "\n\n\n".join(code) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py new file mode 100644 index 0000000000000..586822f2de423 --- /dev/null +++ b/python/pyspark/ml/param/shared.py @@ -0,0 +1,260 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DO NOT MODIFY. The code is generated by _gen_shared_params.py. + +from pyspark.ml.param import Param, Params + + +class HasMaxIter(Params): + """ + Params with maxIter. + """ + + # a placeholder to make it appear in the generated doc + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100) + + def __init__(self): + super(HasMaxIter, self).__init__() + #: param for max number of iterations + self.maxIter = Param(self, "maxIter", "max number of iterations", 100) + + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + self.paramMap[self.maxIter] = value + return self + + def getMaxIter(self): + """ + Gets the value of maxIter or its default value. + """ + if self.maxIter in self.paramMap: + return self.paramMap[self.maxIter] + else: + return self.maxIter.defaultValue + + +class HasRegParam(Params): + """ + Params with regParam. + """ + + # a placeholder to make it appear in the generated doc + regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1) + + def __init__(self): + super(HasRegParam, self).__init__() + #: param for regularization constant + self.regParam = Param(self, "regParam", "regularization constant", 0.1) + + def setRegParam(self, value): + """ + Sets the value of :py:attr:`regParam`. 
+ """ + self.paramMap[self.regParam] = value + return self + + def getRegParam(self): + """ + Gets the value of regParam or its default value. + """ + if self.regParam in self.paramMap: + return self.paramMap[self.regParam] + else: + return self.regParam.defaultValue + + +class HasFeaturesCol(Params): + """ + Params with featuresCol. + """ + + # a placeholder to make it appear in the generated doc + featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features') + + def __init__(self): + super(HasFeaturesCol, self).__init__() + #: param for features column name + self.featuresCol = Param(self, "featuresCol", "features column name", 'features') + + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + self.paramMap[self.featuresCol] = value + return self + + def getFeaturesCol(self): + """ + Gets the value of featuresCol or its default value. + """ + if self.featuresCol in self.paramMap: + return self.paramMap[self.featuresCol] + else: + return self.featuresCol.defaultValue + + +class HasLabelCol(Params): + """ + Params with labelCol. + """ + + # a placeholder to make it appear in the generated doc + labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label') + + def __init__(self): + super(HasLabelCol, self).__init__() + #: param for label column name + self.labelCol = Param(self, "labelCol", "label column name", 'label') + + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + self.paramMap[self.labelCol] = value + return self + + def getLabelCol(self): + """ + Gets the value of labelCol or its default value. + """ + if self.labelCol in self.paramMap: + return self.paramMap[self.labelCol] + else: + return self.labelCol.defaultValue + + +class HasPredictionCol(Params): + """ + Params with predictionCol. 
+ """ + + # a placeholder to make it appear in the generated doc + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction') + + def __init__(self): + super(HasPredictionCol, self).__init__() + #: param for prediction column name + self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction') + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + self.paramMap[self.predictionCol] = value + return self + + def getPredictionCol(self): + """ + Gets the value of predictionCol or its default value. + """ + if self.predictionCol in self.paramMap: + return self.paramMap[self.predictionCol] + else: + return self.predictionCol.defaultValue + + +class HasInputCol(Params): + """ + Params with inputCol. + """ + + # a placeholder to make it appear in the generated doc + inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input') + + def __init__(self): + super(HasInputCol, self).__init__() + #: param for input column name + self.inputCol = Param(self, "inputCol", "input column name", 'input') + + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + self.paramMap[self.inputCol] = value + return self + + def getInputCol(self): + """ + Gets the value of inputCol or its default value. + """ + if self.inputCol in self.paramMap: + return self.paramMap[self.inputCol] + else: + return self.inputCol.defaultValue + + +class HasOutputCol(Params): + """ + Params with outputCol. + """ + + # a placeholder to make it appear in the generated doc + outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output') + + def __init__(self): + super(HasOutputCol, self).__init__() + #: param for output column name + self.outputCol = Param(self, "outputCol", "output column name", 'output') + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. 
+ """ + self.paramMap[self.outputCol] = value + return self + + def getOutputCol(self): + """ + Gets the value of outputCol or its default value. + """ + if self.outputCol in self.paramMap: + return self.paramMap[self.outputCol] + else: + return self.outputCol.defaultValue + + +class HasNumFeatures(Params): + """ + Params with numFeatures. + """ + + # a placeholder to make it appear in the generated doc + numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18) + + def __init__(self): + super(HasNumFeatures, self).__init__() + #: param for number of features + self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18) + + def setNumFeatures(self, value): + """ + Sets the value of :py:attr:`numFeatures`. + """ + self.paramMap[self.numFeatures] = value + return self + + def getNumFeatures(self): + """ + Gets the value of numFeatures or its default value. + """ + if self.numFeatures in self.paramMap: + return self.paramMap[self.numFeatures] + else: + return self.numFeatures.defaultValue diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py new file mode 100644 index 0000000000000..5233c5801e2e6 --- /dev/null +++ b/python/pyspark/ml/pipeline.py @@ -0,0 +1,170 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta, abstractmethod + +from pyspark.ml.param import Param, Params +from pyspark.ml.util import keyword_only +from pyspark.mllib.common import inherit_doc + + +__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] + + +@inherit_doc +class Estimator(Params): + """ + Abstract class for estimators that fit models to data. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def fit(self, dataset, params={}): + """ + Fits a model to the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.SchemaRDD` + :param params: an optional param map that overwrites embedded + params + :returns: fitted model + """ + raise NotImplementedError() + + +@inherit_doc +class Transformer(Params): + """ + Abstract class for transformers that transform one dataset into + another. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def transform(self, dataset, params={}): + """ + Transforms the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.SchemaRDD` + :param params: an optional param map that overwrites embedded + params + :returns: transformed dataset + """ + raise NotImplementedError() + + +@inherit_doc +class Pipeline(Estimator): + """ + A simple pipeline, which acts as an estimator. A Pipeline consists + of a sequence of stages, each of which is either an + :py:class:`Estimator` or a :py:class:`Transformer`. When + :py:meth:`Pipeline.fit` is called, the stages are executed in + order. If a stage is an :py:class:`Estimator`, its + :py:meth:`Estimator.fit` method will be called on the input + dataset to fit a model. Then the model, which is a transformer, + will be used to transform the dataset as the input to the next + stage. 
If a stage is a :py:class:`Transformer`, its + :py:meth:`Transformer.transform` method will be called to produce + the dataset for the next stage. The fitted model from a + :py:class:`Pipeline` is an :py:class:`PipelineModel`, which + consists of fitted models and transformers, corresponding to the + pipeline stages. If there are no stages, the pipeline acts as an + identity transformer. + """ + + @keyword_only + def __init__(self, stages=[]): + """ + __init__(self, stages=[]) + """ + super(Pipeline, self).__init__() + #: Param for pipeline stages. + self.stages = Param(self, "stages", "pipeline stages") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + def setStages(self, value): + """ + Set pipeline stages. + :param value: a list of transformers or estimators + :return: the pipeline instance + """ + self.paramMap[self.stages] = value + return self + + def getStages(self): + """ + Get pipeline stages. + """ + if self.stages in self.paramMap: + return self.paramMap[self.stages] + + @keyword_only + def setParams(self, stages=[]): + """ + setParams(self, stages=[]) + Sets params for Pipeline. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + + def fit(self, dataset, params={}): + paramMap = self._merge_params(params) + stages = paramMap[self.stages] + for stage in stages: + if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): + raise ValueError( + "Cannot recognize a pipeline stage of type %s." 
% type(stage).__name__) + indexOfLastEstimator = -1 + for i, stage in enumerate(stages): + if isinstance(stage, Estimator): + indexOfLastEstimator = i + transformers = [] + for i, stage in enumerate(stages): + if i <= indexOfLastEstimator: + if isinstance(stage, Transformer): + transformers.append(stage) + dataset = stage.transform(dataset, paramMap) + else: # must be an Estimator + model = stage.fit(dataset, paramMap) + transformers.append(model) + if i < indexOfLastEstimator: + dataset = model.transform(dataset, paramMap) + else: + transformers.append(stage) + return PipelineModel(transformers) + + +@inherit_doc +class PipelineModel(Transformer): + """ + Represents a compiled pipeline with transformers and fitted models. + """ + + def __init__(self, transformers): + super(PipelineModel, self).__init__() + self.transformers = transformers + + def transform(self, dataset, params={}): + paramMap = self._merge_params(params) + for t in self.transformers: + dataset = t.transform(dataset, paramMap) + return dataset diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py new file mode 100644 index 0000000000000..b627c2b4e930b --- /dev/null +++ b/python/pyspark/ml/tests.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for Spark ML Python APIs. +""" + +import sys + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase +from pyspark.sql import DataFrame +from pyspark.ml.param import Param +from pyspark.ml.pipeline import Transformer, Estimator, Pipeline + + +class MockDataset(DataFrame): + + def __init__(self): + self.index = 0 + + +class MockTransformer(Transformer): + + def __init__(self): + super(MockTransformer, self).__init__() + self.fake = Param(self, "fake", "fake", None) + self.dataset_index = None + self.fake_param_value = None + + def transform(self, dataset, params={}): + self.dataset_index = dataset.index + if self.fake in params: + self.fake_param_value = params[self.fake] + dataset.index += 1 + return dataset + + +class MockEstimator(Estimator): + + def __init__(self): + super(MockEstimator, self).__init__() + self.fake = Param(self, "fake", "fake", None) + self.dataset_index = None + self.fake_param_value = None + self.model = None + + def fit(self, dataset, params={}): + self.dataset_index = dataset.index + if self.fake in params: + self.fake_param_value = params[self.fake] + model = MockModel() + self.model = model + return model + + +class MockModel(MockTransformer, Transformer): + + def __init__(self): + super(MockModel, self).__init__() + + +class PipelineTests(PySparkTestCase): + + def test_pipeline(self): + dataset = MockDataset() + estimator0 = MockEstimator() + transformer1 = MockTransformer() + estimator2 = MockEstimator() + transformer3 = MockTransformer() + pipeline = Pipeline() \ + .setStages([estimator0, transformer1, estimator2, transformer3]) + pipeline_model = pipeline.fit(dataset, 
{estimator0.fake: 0, transformer1.fake: 1}) + self.assertEqual(0, estimator0.dataset_index) + self.assertEqual(0, estimator0.fake_param_value) + model0 = estimator0.model + self.assertEqual(0, model0.dataset_index) + self.assertEqual(1, transformer1.dataset_index) + self.assertEqual(1, transformer1.fake_param_value) + self.assertEqual(2, estimator2.dataset_index) + model2 = estimator2.model + self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should " + "not be called during fit.") + dataset = pipeline_model.transform(dataset) + self.assertEqual(2, model0.dataset_index) + self.assertEqual(3, transformer1.dataset_index) + self.assertEqual(4, model2.dataset_index) + self.assertEqual(5, transformer3.dataset_index) + self.assertEqual(6, dataset.index) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py new file mode 100644 index 0000000000000..6f7f39c40eb5a --- /dev/null +++ b/python/pyspark/ml/util.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from functools import wraps +import uuid + + +def keyword_only(func): + """ + A decorator that forces keyword arguments in the wrapped method + and saves actual input keyword arguments in `_input_kwargs`. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > 1: + raise TypeError("Method %s forces keyword arguments." % func.__name__) + wrapper._input_kwargs = kwargs + return func(*args, **kwargs) + return wrapper + + +class Identifiable(object): + """ + Object with a unique ID. + """ + + def __init__(self): + #: A unique id for the object. The default implementation + #: concatenates the class name, "-", and 8 random hex chars. + self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + + def __repr__(self): + return self.uid diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py new file mode 100644 index 0000000000000..4bae96f678388 --- /dev/null +++ b/python/pyspark/ml/wrapper.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from abc import ABCMeta + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from pyspark.ml.param import Params +from pyspark.ml.pipeline import Estimator, Transformer +from pyspark.mllib.common import inherit_doc + + +def _jvm(): + """ + Returns the JVM view associated with SparkContext. Must be called + after SparkContext is initialized. + """ + jvm = SparkContext._jvm + if jvm: + return jvm + else: + raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") + + +@inherit_doc +class JavaWrapper(Params): + """ + Utility class to help create wrapper classes from Java/Scala + implementations of pipeline components. + """ + + __metaclass__ = ABCMeta + + #: Fully-qualified class name of the wrapped Java component. + _java_class = None + + def _java_obj(self): + """ + Returns or creates a Java object. + """ + java_obj = _jvm() + for name in self._java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj() + + def _transfer_params_to_java(self, params, java_obj): + """ + Transforms the embedded params and additional params to the + input Java object. + :param params: additional params (overwriting embedded values) + :param java_obj: Java object to receive the params + """ + paramMap = self._merge_params(params) + for param in self.params: + if param in paramMap: + java_obj.set(param.name, paramMap[param]) + + def _empty_java_param_map(self): + """ + Returns an empty Java ParamMap reference. + """ + return _jvm().org.apache.spark.ml.param.ParamMap() + + def _create_java_param_map(self, params, java_obj): + paramMap = self._empty_java_param_map() + for param, value in params.items(): + if param.parent is self: + paramMap.put(java_obj.getParam(param.name), value) + return paramMap + + +@inherit_doc +class JavaEstimator(Estimator, JavaWrapper): + """ + Base class for :py:class:`Estimator`s that wrap Java/Scala + implementations. 
+ """ + + __metaclass__ = ABCMeta + + def _create_model(self, java_model): + """ + Creates a model from the input Java model reference. + """ + return JavaModel(java_model) + + def _fit_java(self, dataset, params={}): + """ + Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.SchemaRDD` + :param params: additional params (overwriting embedded values) + :return: fitted Java model + """ + java_obj = self._java_obj() + self._transfer_params_to_java(params, java_obj) + return java_obj.fit(dataset._jdf, self._empty_java_param_map()) + + def fit(self, dataset, params={}): + java_model = self._fit_java(dataset, params) + return self._create_model(java_model) + + +@inherit_doc +class JavaTransformer(Transformer, JavaWrapper): + """ + Base class for :py:class:`Transformer`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def transform(self, dataset, params={}): + java_obj = self._java_obj() + self._transfer_params_to_java({}, java_obj) + java_param_map = self._create_java_param_map(params, java_obj) + return DataFrame(java_obj.transform(dataset._jdf, java_param_map), + dataset.sql_ctx) + + +@inherit_doc +class JavaModel(JavaTransformer): + """ + Base class for :py:class:`Model`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def __init__(self, java_model): + super(JavaTransformer, self).__init__() + self._java_model = java_model + + def _java_obj(self): + return self._java_model diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index c3217620e3c4e..6449800d9c120 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -19,7 +19,7 @@ Python bindings for MLlib. 
""" -# MLlib currently needs and NumPy 1.4+, so complain if lower +# MLlib currently needs NumPy 1.4+, so complain if lower import numpy if numpy.version.version < '1.4': diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e2492eef5bd6a..949db5705abd7 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -15,19 +15,22 @@ # limitations under the License. # +from numpy import array + +from pyspark import RDD from pyspark import SparkContext from pyspark.mllib.common import callMLlibFunc, callJavaFunc -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector +from pyspark.mllib.stat.distribution import MultivariateGaussian -__all__ = ['KMeansModel', 'KMeans'] +__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture'] class KMeansModel(object): """A clustering model derived from the k-means method. - >>> from numpy import array - >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) + >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2) >>> model = KMeans.train( ... 
sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random") >>> model.predict(array([0.0, 0.0])) == model.predict(array([1.0, 1.0])) @@ -78,14 +81,95 @@ def predict(self, x): class KMeans(object): @classmethod - def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): + def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None): """Train a k-means clustering model.""" model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, - runs, initializationMode) + runs, initializationMode, seed) centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) +class GaussianMixtureModel(object): + + """A clustering model derived from the Gaussian Mixture Model method. + + >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, + ... 0.9,0.8,0.75,0.935, + ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) + >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001, + ... maxIterations=50, seed=10) + >>> labels = model.predict(clusterdata_1).collect() + >>> labels[0]==labels[1] + False + >>> labels[1]==labels[2] + True + >>> labels[4]==labels[5] + True + >>> clusterdata_2 = sc.parallelize(array([-5.1971, -2.5359, -3.8220, + ... -5.2211, -5.0602, 4.7118, + ... 6.8989, 3.4592, 4.6322, + ... 5.7048, 4.6567, 5.5026, + ... 4.5605, 5.2043, 6.2734]).reshape(5, 3)) + >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, + ... maxIterations=150, seed=10) + >>> labels = model.predict(clusterdata_2).collect() + >>> labels[0]==labels[1]==labels[2] + True + >>> labels[3]==labels[4] + True + """ + + def __init__(self, weights, gaussians): + self.weights = weights + self.gaussians = gaussians + self.k = len(self.weights) + + def predict(self, x): + """ + Find the cluster to which the points in 'x' has maximum membership + in this model. + + :param x: RDD of data points. + :return: cluster_labels. 
RDD of cluster labels. + """ + if isinstance(x, RDD): + cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) + return cluster_labels + + def predictSoft(self, x): + """ + Find the membership of each point in 'x' to all mixture components. + + :param x: RDD of data points. + :return: membership_matrix. RDD of array of double values. + """ + if isinstance(x, RDD): + means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) + membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), + self.weights, means, sigmas) + return membership_matrix + + +class GaussianMixture(object): + """ + Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. + + :param data: RDD of data points + :param k: Number of components + :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 + :param maxIterations: Number of iterations. Default to 100 + :param seed: Random Seed + """ + @classmethod + def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None): + """Train a Gaussian Mixture clustering model.""" + weight, mu, sigma = callMLlibFunc("trainGaussianMixture", + rdd.map(_convert_to_vector), k, + convergenceTol, maxIterations, seed) + mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] + return GaussianMixtureModel(weight, mvg_obj) + + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 3c5ee66cd8b64..621591c26b77f 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -134,3 +134,20 @@ def __del__(self): def call(self, name, *a): """Call method of java_model""" return callJavaFunc(self._sc, getattr(self._java_model, name), *a) + + +def inherit_doc(cls): + """ + A decorator that makes a class inherit documentation from its parents. 
+ """ + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 10df6288065b8..0ffe092a07365 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -58,7 +58,8 @@ class Normalizer(VectorTransformer): For any 1 <= `p` < float('inf'), normalizes samples using sum(abs(vector) :sup:`p`) :sup:`(1/p)` as norm. - For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + For `p` = float('inf'), max(abs(vector)) will be used as norm for + normalization. >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) @@ -120,9 +121,14 @@ def transform(self, vector): """ Applies standardization transformation on a vector. + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + :param vector: Vector or RDD of Vector to be standardized. - :return: Standardized vector. If the variance of a column is zero, - it will return default `0.0` for the column with zero variance. + :return: Standardized vector. If the variance of a column is + zero, it will return default `0.0` for the column with + zero variance. """ return JavaVectorTransformer.transform(self, vector) @@ -148,9 +154,10 @@ def __init__(self, withMean=False, withStd=True): """ :param withMean: False by default. Centers the data with mean before scaling. It will build a dense output, so this - does not work on sparse input and will raise an exception. - :param withStd: True by default. Scales the data to unit standard - deviation. + does not work on sparse input and will raise an + exception. + :param withStd: True by default. 
Scales the data to unit + standard deviation. """ if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") @@ -159,10 +166,11 @@ def __init__(self, withMean=False, withStd=True): def fit(self, dataset): """ - Computes the mean and variance and stores as a model to be used for later scaling. + Computes the mean and variance and stores as a model to be used + for later scaling. - :param data: The data used to compute the mean and variance to build - the transformation model. + :param data: The data used to compute the mean and variance + to build the transformation model. :return: a StandardScalarModel """ dataset = dataset.map(_convert_to_vector) @@ -174,7 +182,8 @@ class HashingTF(object): """ .. note:: Experimental - Maps a sequence of terms to their term frequencies using the hashing trick. + Maps a sequence of terms to their term frequencies using the hashing + trick. Note: the terms must be hashable (can not be dict/set/list...). @@ -195,8 +204,9 @@ def indexOf(self, term): def transform(self, document): """ - Transforms the input document (list of terms) to term frequency vectors, - or transform the RDD of document to RDD of term frequency vectors. + Transforms the input document (list of terms) to term frequency + vectors, or transform the RDD of document to RDD of term + frequency vectors. """ if isinstance(document, RDD): return document.map(self.transform) @@ -220,7 +230,12 @@ def transform(self, x): the terms which occur in fewer than `minDocFreq` documents will have an entry of 0. - :param x: an RDD of term frequency vectors or a term frequency vector + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. 
+ + :param x: an RDD of term frequency vectors or a term frequency + vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ if isinstance(x, RDD): @@ -241,9 +256,9 @@ class IDF(object): of documents that contain term `t`. This implementation supports filtering out terms which do not appear - in a minimum number of documents (controlled by the variable `minDocFreq`). - For terms that are not in at least `minDocFreq` documents, the IDF is - found as 0, resulting in TF-IDFs of 0. + in a minimum number of documents (controlled by the variable + `minDocFreq`). For terms that are not in at least `minDocFreq` + documents, the IDF is found as 0, resulting in TF-IDFs of 0. >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), @@ -325,15 +340,16 @@ class Word2Vec(object): The vector representation can be used as features in natural language processing and machine learning algorithms. - We used skip-gram model in our implementation and hierarchical softmax - method to train the model. The variable names in the implementation - matches the original C implementation. + We used skip-gram model in our implementation and hierarchical + softmax method to train the model. The variable names in the + implementation matches the original C implementation. - For original C implementation, see https://code.google.com/p/word2vec/ + For original C implementation, + see https://code.google.com/p/word2vec/ For research papers, see Efficient Estimation of Word Representations in Vector Space - and - Distributed Representations of Words and Phrases and their Compositionality. + and Distributed Representations of Words and Phrases and their + Compositionality. >>> sentence = "a b " * 100 + "a c " * 10 >>> localDoc = [sentence, sentence] @@ -374,15 +390,16 @@ def setLearningRate(self, learningRate): def setNumPartitions(self, numPartitions): """ - Sets number of partitions (default: 1). Use a small number for accuracy. + Sets number of partitions (default: 1). 
Use a small number for + accuracy. """ self.numPartitions = numPartitions return self def setNumIterations(self, numIterations): """ - Sets number of iterations (default: 1), which should be smaller than or equal to number of - partitions. + Sets number of iterations (default: 1), which should be smaller + than or equal to number of partitions. """ self.numIterations = numIterations return self diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 7f21190ed8c25..597012b1c967c 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,7 +29,7 @@ import numpy as np -from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ +from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ IntegerType, ByteType diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 97ec74eda0b71..0d99e6dedfad9 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -49,17 +49,17 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> r3 = (2, 1, 2.0) >>> ratings = sc.parallelize([r1, r2, r3]) >>> model = ALS.trainImplicit(ratings, 1, seed=10) - >>> model.predict(2,2) - 0.4473... + >>> model.predict(2, 2) + 0.43... 
>>> testset = sc.parallelize([(1, 2), (1, 1)]) - >>> model = ALS.train(ratings, 1, seed=10) + >>> model = ALS.train(ratings, 2, seed=0) >>> model.predictAll(testset).collect() - [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)] + [Rating(user=1, product=1, rating=1.0...), Rating(user=1, product=2, rating=1.9...)] >>> model = ALS.train(ratings, 4, seed=10) >>> model.userFeatures().collect() - [(2, array('d', [...])), (1, array('d', [...]))] + [(1, array('d', [...])), (2, array('d', [...]))] >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] @@ -67,7 +67,7 @@ class MatrixFactorizationModel(JavaModelWrapper): True >>> model.productFeatures().collect() - [(2, array('d', [...])), (1, array('d', [...]))] + [(1, array('d', [...])), (2, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] @@ -76,11 +76,11 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) - 3.735... + 3.8... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) - 0.4473... + 0.43... """ def predict(self, user, product): return self._java_model.predict(int(user), int(product)) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 210060140fd91..66617abb85670 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,7 +18,7 @@ import numpy as np from numpy import array -from pyspark.mllib.common import callMLlibFunc +from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', @@ -31,8 +31,11 @@ class LabeledPoint(object): The features and labels of a data point. :param label: Label for this data point. 
- :param features: Vector of features for this point (NumPy array, list, - pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) + :param features: Vector of features for this point (NumPy array, + list, pyspark.mllib.linalg.SparseVector, or scipy.sparse + column matrix) + + Note: 'label' and 'features' are accessible as class attributes. """ def __init__(self, label, features): @@ -69,6 +72,7 @@ def __repr__(self): return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) +@inherit_doc class LinearRegressionModelBase(LinearModel): """A linear regression model. @@ -89,6 +93,7 @@ def predict(self, x): return self.weights.dot(x) + self.intercept +@inherit_doc class LinearRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit. @@ -162,7 +167,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features - are activated or not). + are activated or not). 
(default: False) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -172,6 +177,7 @@ def train(rdd, i): return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) +@inherit_doc class LassoModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit with an @@ -218,6 +224,7 @@ def train(rdd, i): return _regression_train_wrapper(train, LassoModel, data, initialWeights) +@inherit_doc class RidgeRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit with an diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py new file mode 100644 index 0000000000000..e3e128513e0d7 --- /dev/null +++ b/python/pyspark/mllib/stat/__init__.py @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for statistical functions in MLlib. 
+""" + +from pyspark.mllib.stat._statistics import * +from pyspark.mllib.stat.distribution import MultivariateGaussian +from pyspark.mllib.stat.test import ChiSqTestResult + +__all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult", + "MultivariateGaussian"] diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat/_statistics.py similarity index 88% rename from python/pyspark/mllib/stat.py rename to python/pyspark/mllib/stat/_statistics.py index c8af777a8b00d..218ac148ca992 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -15,17 +15,14 @@ # limitations under the License. # -""" -Python package for statistical functions in MLlib. -""" - from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat.test import ChiSqTestResult -__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] +__all__ = ['MultivariateStatisticalSummary', 'Statistics'] class MultivariateStatisticalSummary(JavaModelWrapper): @@ -53,54 +50,6 @@ def min(self): return self.call("min").toArray() -class ChiSqTestResult(JavaModelWrapper): - """ - .. note:: Experimental - - Object containing the test results for the chi-squared hypothesis test. - """ - @property - def method(self): - """ - Name of the test method - """ - return self._java_model.method() - - @property - def pValue(self): - """ - The probability of obtaining a test statistic result at least as - extreme as the one that was actually observed, assuming that the - null hypothesis is true. - """ - return self._java_model.pValue() - - @property - def degreesOfFreedom(self): - """ - Returns the degree(s) of freedom of the hypothesis test. - Return type should be Number(e.g. Int, Double) or tuples of Numbers. 
- """ - return self._java_model.degreesOfFreedom() - - @property - def statistic(self): - """ - Test statistic. - """ - return self._java_model.statistic() - - @property - def nullHypothesis(self): - """ - Null hypothesis of the test. - """ - return self._java_model.nullHypothesis() - - def __str__(self): - return self._java_model.toString() - - class Statistics(object): @staticmethod diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py new file mode 100644 index 0000000000000..46f7a1d2f277a --- /dev/null +++ b/python/pyspark/mllib/stat/distribution.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import namedtuple + +__all__ = ['MultivariateGaussian'] + + +class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])): + + """Represents a (mu, sigma) tuple + + >>> m = MultivariateGaussian(Vectors.dense([11,12]),DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0))) + >>> (m.mu, m.sigma.toArray()) + (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) + >>> (m[0], m[1]) + (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) + """ diff --git a/python/pyspark/mllib/stat/test.py b/python/pyspark/mllib/stat/test.py new file mode 100644 index 0000000000000..762506e952b43 --- /dev/null +++ b/python/pyspark/mllib/stat/test.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.mllib.common import JavaModelWrapper + + +__all__ = ["ChiSqTestResult"] + + +class ChiSqTestResult(JavaModelWrapper): + """ + .. note:: Experimental + + Object containing the test results for the chi-squared hypothesis test. 
+ """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. + """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 140c22b5fd4e8..06207a076eece 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -140,7 +140,7 @@ class ListTests(PySparkTestCase): as NumPy arrays. """ - def test_clustering(self): + def test_kmeans(self): from pyspark.mllib.clustering import KMeans data = [ [0, 1.1], @@ -152,9 +152,50 @@ def test_clustering(self): self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + def test_kmeans_deterministic(self): + from pyspark.mllib.clustering import KMeans + X = range(0, 100, 10) + Y = range(0, 100, 10) + data = [[x, y] for x, y in zip(X, Y)] + clusters1 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + clusters2 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + centers1 = clusters1.centers + centers2 = clusters2.centers + for c1, c2 in zip(centers1, centers2): + # TODO: Allow small numeric difference. 
+ self.assertTrue(array_equal(c1, c2)) + + def test_gmm(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + [1, 2], + [8, 9], + [-4, -3], + [-6, -7], + ]) + clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=100, seed=56) + labels = clusters.predict(data).collect() + self.assertEquals(labels[0], labels[1]) + self.assertEquals(labels[2], labels[3]) + + def test_gmm_deterministic(self): + from pyspark.mllib.clustering import GaussianMixture + x = range(0, 100, 10) + y = range(0, 100, 10) + data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) + clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=100, seed=63) + clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=100, seed=63) + for c1, c2 in zip(clusters1.weights, clusters2.weights): + self.assertEquals(round(c1, 7), round(c2, 7)) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -183,18 +224,31 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[3]) > 0) categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = \ - DecisionTree.trainClassifier(rdd, numClasses=2, - categoricalFeaturesInfo=categoricalFeaturesInfo) + dt_model = DecisionTree.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) + rf_model = RandomForest.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + 
self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainClassifier( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees data = [ LabeledPoint(-1.0, [0, -1]), LabeledPoint(1.0, [0, 1]), @@ -223,13 +277,27 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[3]) > 0) categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = \ - DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + dt_model = DecisionTree.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) + rf_model = RandomForest.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100, seed=1) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + 
self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + class StatTests(PySparkTestCase): # SPARK-4023 @@ -267,7 +335,7 @@ def test_infer_schema(self): sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) srdd = sqlCtx.inferSchema(rdd) - schema = srdd.schema() + schema = srdd.schema field = [f for f in schema.fields if f.name == "features"][0] self.assertEqual(field.dataType, self.udt) vectors = srdd.map(lambda p: p.features).collect() diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 66702478474dc..73618f0449ad4 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -20,25 +20,66 @@ import random from pyspark import SparkContext, RDD -from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint -__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', 'RandomForest'] +__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', + 'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees'] -class DecisionTreeModel(JavaModelWrapper): +class TreeEnsembleModel(JavaModelWrapper): + def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. + + Note: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. + """ + if isinstance(x, RDD): + return self.call("predict", x.map(_convert_to_vector)) + + else: + return self.call("predict", _convert_to_vector(x)) + + def numTrees(self): + """ + Get number of trees in ensemble. + """ + return self.call("numTrees") + + def totalNumNodes(self): + """ + Get total number of nodes, summed over all trees in the + ensemble. 
+ """ + return self.call("totalNumNodes") + + def __repr__(self): + """ Summary of model """ + return self._java_model.toString() + + def toDebugString(self): + """ Full model """ + return self._java_model.toDebugString() + +class DecisionTreeModel(JavaModelWrapper): """ - A decision tree model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + A decision tree model for classification or regression. """ def predict(self, x): """ Predict the label of one or more examples. + Note: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. + :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ @@ -64,12 +105,11 @@ def toDebugString(self): class DecisionTree(object): - """ - Learning algorithm for a decision tree model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Learning algorithm for a decision tree model for classification or + regression. """ @classmethod @@ -146,17 +186,17 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, :param data: Training data: RDD of LabeledPoint. Labels are real numbers. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. + :param categoricalFeaturesInfo: Map from categorical feature + index to number of categories. + Any feature not in this map is treated as continuous. :param impurity: Supported values: "variance" :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. 
- :param minInstancesPerNode: Min number of instances required at child - nodes to create the parent split + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each + node. + :param minInstancesPerNode: Min number of instances required at + child nodes to create the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel @@ -186,51 +226,21 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) -class RandomForestModel(JavaModelWrapper): +@inherit_doc +class RandomForestModel(TreeEnsembleModel): """ - Represents a random forest model. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Represents a random forest model. """ - def predict(self, x): - """ - Predict values for a single data point or an RDD of points using - the model trained. - """ - if isinstance(x, RDD): - return self.call("predict", x.map(_convert_to_vector)) - - else: - return self.call("predict", _convert_to_vector(x)) - - def numTrees(self): - """ - Get number of trees in forest. - """ - return self.call("numTrees") - - def totalNumNodes(self): - """ - Get total number of nodes, summed over all trees in the forest. - """ - return self.call("totalNumNodes") - - def __repr__(self): - """ Summary of model """ - return self._java_model.toString() - - def toDebugString(self): - """ Full model """ - return self._java_model.toDebugString() class RandomForest(object): """ - Learning algorithm for a random forest model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Learning algorithm for a random forest model for classification or + regression. 
""" supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -257,26 +267,33 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Method to train a decision tree model for binary or multiclass classification. - :param data: Training dataset: RDD of LabeledPoint. Labels should take - values {0, 1, ..., numClasses-1}. + :param data: Training dataset: RDD of LabeledPoint. Labels + should take values {0, 1, ..., numClasses-1}. :param numClasses: number of classes for classification. - :param categoricalFeaturesInfo: Map storing arity of categorical features. - E.g., an entry (n -> k) indicates that feature n is categorical - with k categories indexed from 0: {0, 1, ..., k-1}. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that + feature n is categorical with k categories indexed + from 0: {0, 1, ..., k-1}. :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for splits at - each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". - :param impurity: Criterion used for information gain calculation. + :param featureSubsetStrategy: Number of features to consider for + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", + "onethird". + If "auto" is set, this parameter is set based on + numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + :param impurity: Criterion used for information gain + calculation. Supported values: "gini" (recommended) or "entropy". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; - depth 1 means 1 internal node + 2 leaf nodes. 
(default: 4) - :param maxBins: maximum number of bins used for splitting features + :param maxDepth: Maximum depth of the tree. + E.g., depth 0 means 1 leaf node; depth 1 means + 1 internal node + 2 leaf nodes. (default: 4) + :param maxBins: maximum number of bins used for splitting + features (default: 100) - :param seed: Random seed for bootstrapping and choosing feature subsets. + :param seed: Random seed for bootstrapping and choosing feature + subsets. :return: RandomForestModel that can be used for prediction Example usage: @@ -338,19 +355,24 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt {0, 1, ..., k-1}. :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for regression. - :param impurity: Criterion used for information gain calculation. - Supported values: "variance". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 - leaf node; depth 1 means 1 internal node + 2 leaf nodes. - (default: 4) - :param maxBins: maximum number of bins used for splitting features - (default: 100) - :param seed: Random seed for bootstrapping and choosing feature subsets. + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", + "onethird". + If "auto" is set, this parameter is set based on + numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for + regression. + :param impurity: Criterion used for information gain + calculation. + Supported values: "variance". + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. 
(default: 4) + :param maxBins: maximum number of bins used for splitting + features (default: 100) + :param seed: Random seed for bootstrapping and choosing feature + subsets. :return: RandomForestModel that can be used for prediction Example usage: @@ -383,6 +405,147 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt featureSubsetStrategy, impurity, maxDepth, maxBins, seed) +@inherit_doc +class GradientBoostedTreesModel(TreeEnsembleModel): + """ + .. note:: Experimental + + Represents a gradient-boosted tree model. + """ + + +class GradientBoostedTrees(object): + """ + .. note:: Experimental + + Learning algorithm for a gradient boosted trees model for + classification or regression. + """ + + @classmethod + def _train(cls, data, algo, categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth): + first = data.first() + assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" + model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + return GradientBoostedTreesModel(model) + + @classmethod + def trainClassifier(cls, data, categoricalFeaturesInfo, + loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3): + """ + Method to train a gradient-boosted trees model for + classification. + + :param data: Training dataset: RDD of LabeledPoint. + Labels should take values {0, 1}. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param loss: Loss function used for minimization during gradient + boosting. Supported: {"logLoss" (default), + "leastSquaresError", "leastAbsoluteError"}. + :param numIterations: Number of iterations of boosting. + (default: 100) + :param learningRate: Learning rate for shrinking the + contribution of each estimator. 
The learning rate + should be in the interval (0, 1]. + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 3) + :return: GradientBoostedTreesModel that can be used for + prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import GradientBoostedTrees + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(0.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> + >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}) + >>> model.numTrees() + 100 + >>> model.totalNumNodes() + 300 + >>> print model, # it already has newline + TreeEnsembleModel classifier with 100 trees + >>> model.predict([2.0]) + 1.0 + >>> model.predict([0.0]) + 0.0 + >>> rdd = sc.parallelize([[2.0], [0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "classification", categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + + @classmethod + def trainRegressor(cls, data, categoricalFeaturesInfo, + loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3): + """ + Method to train a gradient-boosted trees model for regression. + + :param data: Training dataset: RDD of LabeledPoint. Labels are + real numbers. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param loss: Loss function used for minimization during gradient + boosting. Supported: {"logLoss", + "leastSquaresError" (default), "leastAbsoluteError"}. + :param numIterations: Number of iterations of boosting. + (default: 100) + :param learningRate: Learning rate for shrinking the + contribution of each estimator. The learning rate + should be in the interval (0, 1].
+ (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 3) + :return: GradientBoostedTreesModel that can be used for + prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import GradientBoostedTrees + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> model = GradientBoostedTrees.trainRegressor(sc.parallelize(sparse_data), {}) + >>> model.numTrees() + 100 + >>> model.totalNumNodes() + 102 + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {0: 1.0})) + 0.0 + >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "regression", categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py new file mode 100644 index 0000000000000..4408996db0790 --- /dev/null +++ b/python/pyspark/profiler.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import cProfile +import pstats +import os +import atexit + +from pyspark.accumulators import AccumulatorParam + + +class ProfilerCollector(object): + """ + This class keeps track of different profilers on a per + stage basis. Also this is used to create new profilers for + the different stages. + """ + + def __init__(self, profiler_cls, dump_path=None): + self.profiler_cls = profiler_cls + self.profile_dump_path = dump_path + self.profilers = [] + + def new_profiler(self, ctx): + """ Create a new profiler using class `profiler_cls` """ + return self.profiler_cls(ctx) + + def add_profiler(self, id, profiler): + """ Add a profiler for RDD `id` """ + if not self.profilers: + if self.profile_dump_path: + atexit.register(self.dump_profiles, self.profile_dump_path) + else: + atexit.register(self.show_profiles) + + self.profilers.append([id, profiler, False]) + + def dump_profiles(self, path): + """ Dump the profile stats into directory `path` """ + for id, profiler, _ in self.profilers: + profiler.dump(id, path) + self.profilers = [] + + def show_profiles(self): + """ Print the profile stats to stdout """ + for i, (id, profiler, showed) in enumerate(self.profilers): + if not showed and profiler: + profiler.show(id) + # mark it as showed + self.profilers[i][2] = True + + +class Profiler(object): + """ + .. note:: DeveloperApi + + PySpark supports custom profilers, this is to allow for different profilers to + be used as well as outputting to different formats than what is provided in the + BasicProfiler. 
+ + A custom profiler has to define or inherit the following methods: + profile - will produce a system profile of some sort. + stats - return the collected stats. + dump - dumps the profiles to a path + add - adds a profile to the existing accumulated profile + + The profiler class is chosen when creating a SparkContext + + >>> from pyspark import SparkConf, SparkContext + >>> from pyspark import BasicProfiler + >>> class MyCustomProfiler(BasicProfiler): + ... def show(self, id): + ... print "My custom profiles for RDD:%s" % id + ... + >>> conf = SparkConf().set("spark.python.profile", "true") + >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) + >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) + [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.show_profiles() + My custom profiles for RDD:1 + My custom profiles for RDD:2 + >>> sc.stop() + """ + + def __init__(self, ctx): + pass + + def profile(self, func): + """ Do profiling on the function `func`""" + raise NotImplemented + + def stats(self): + """ Return the collected profiling stats (pstats.Stats)""" + raise NotImplemented + + def show(self, id): + """ Print the profile stats to stdout, id is the RDD id """ + stats = self.stats() + if stats: + print "=" * 60 + print "Profile of RDD<id=%d>" % id + print "=" * 60 + stats.sort_stats("time", "cumulative").print_stats() + + def dump(self, id, path): + """ Dump the profile into path, id is the RDD id """ + if not os.path.exists(path): + os.makedirs(path) + stats = self.stats() + if stats: + p = os.path.join(path, "rdd_%d.pstats" % id) + stats.dump_stats(p) + + +class PStatsParam(AccumulatorParam): + """PStatsParam is used to merge pstats.Stats""" + + @staticmethod + def zero(value): + return None + + @staticmethod + def addInPlace(value1, value2): + if value1 is None: + return value2 + value1.add(value2) + return value1 + + +class BasicProfiler(Profiler): + """ + BasicProfiler is the default profiler, which is implemented based
on + cProfile and Accumulator + """ + def __init__(self, ctx): + Profiler.__init__(self, ctx) + # Creates a new accumulator for combining the profiles of different + # partitions of a stage + self._accumulator = ctx.accumulator(None, PStatsParam) + + def profile(self, func): + """ Runs and profiles the function `func` passed in. The stats are added to the accumulator; nothing is returned. """ + pr = cProfile.Profile() + pr.runcall(func) + st = pstats.Stats(pr) + st.stream = None # make it picklable + st.strip_dirs() + + # Adds a new profile to the existing accumulated value + self._accumulator.add(st) + + def stats(self): + return self._accumulator.value + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c1120cf781e5e..cb12fed98c53d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -29,9 +29,8 @@ import heapq import bisect import random -from math import sqrt, log, isinf, isnan +from math import sqrt, log, isinf, isnan, pow, ceil -from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -112,6 +111,19 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) +class Partitioner(object): + def __init__(self, numPartitions, partitionFunc): + self.numPartitions = numPartitions + self.partitionFunc = partitionFunc + + def __eq__(self, other): + return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions + and self.partitionFunc == other.partitionFunc) + + def __call__(self, k): + return self.partitionFunc(k) % self.numPartitions + + class RDD(object): """ @@ -127,7 +139,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri self.ctx = ctx self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() - self._partitionFunc = None + self.partitioner = None
def _pickled(self): return self._reserialize(AutoBatchedSerializer(PickleSerializer())) @@ -141,6 +153,17 @@ def id(self): def __repr__(self): return self._jrdd.toString() + def __getnewargs__(self): + # This method is called when attempting to pickle an RDD, which is always an error: + raise Exception( + "It appears that you are attempting to broadcast an RDD or reference an RDD from an " + "action or transformation. RDD transformations and actions can only be invoked by the " + "driver, not inside of other transformations; for example, " + "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values " + "transformation and count action cannot be performed inside of the rdd1.map " + "transformation. For more information, see SPARK-5063." + ) + @property def context(self): """ @@ -440,14 +463,17 @@ def union(self, other): if self._jrdd_deserializer == other._jrdd_deserializer: rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, self._jrdd_deserializer) - return rdd else: # These RDDs contain data in different serialized formats, so we # must normalize them to the default serializer. self_copy = self._reserialize() other_copy = other._reserialize() - return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, - self.ctx.serializer) + rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + if (self.partitioner == other.partitioner and + self.getNumPartitions() == rdd.getNumPartitions()): + rdd.partitioner = self.partitioner + return rdd def intersection(self, other): """ @@ -716,6 +742,43 @@ def func(iterator): return reduce(f, vals) raise ValueError("Can not reduce() empty RDD") + def treeReduce(self, f, depth=2): + """ + Reduces the elements of this RDD in a multi-level tree pattern. 
+ + :param depth: suggested depth of the tree (default: 2) + + >>> add = lambda x, y: x + y + >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10) + >>> rdd.treeReduce(add) + -5 + >>> rdd.treeReduce(add, 1) + -5 + >>> rdd.treeReduce(add, 2) + -5 + >>> rdd.treeReduce(add, 5) + -5 + >>> rdd.treeReduce(add, 10) + -5 + """ + if depth < 1: + raise ValueError("Depth cannot be smaller than 1 but got %d." % depth) + + zeroValue = None, True # Use the second entry to indicate whether this is a dummy value. + + def op(x, y): + if x[1]: + return y + elif y[1]: + return x + else: + return f(x[0], y[0]), False + + reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth) + if reduced[1]: + raise ValueError("Cannot reduce empty RDD.") + return reduced[0] + def fold(self, zeroValue, op): """ Aggregate the elements of each partition, and then the results for all @@ -767,6 +830,58 @@ def func(iterator): return self.mapPartitions(func).fold(zeroValue, combOp) + def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): + """ + Aggregates the elements of this RDD in a multi-level tree + pattern. + + :param depth: suggested depth of the tree (default: 2) + + >>> add = lambda x, y: x + y + >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10) + >>> rdd.treeAggregate(0, add, add) + -5 + >>> rdd.treeAggregate(0, add, add, 1) + -5 + >>> rdd.treeAggregate(0, add, add, 2) + -5 + >>> rdd.treeAggregate(0, add, add, 5) + -5 + >>> rdd.treeAggregate(0, add, add, 10) + -5 + """ + if depth < 1: + raise ValueError("Depth cannot be smaller than 1 but got %d." 
% depth) + + if self.getNumPartitions() == 0: + return zeroValue + + def aggregatePartition(iterator): + acc = zeroValue + for obj in iterator: + acc = seqOp(acc, obj) + yield acc + + partiallyAggregated = self.mapPartitions(aggregatePartition) + numPartitions = partiallyAggregated.getNumPartitions() + scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2) + # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree + # aggregation. + while numPartitions > scale + numPartitions / scale: + numPartitions /= scale + curNumPartitions = numPartitions + + def mapPartition(i, iterator): + for obj in iterator: + yield (i % curNumPartitions, obj) + + partiallyAggregated = partiallyAggregated \ + .mapPartitionsWithIndex(mapPartition) \ + .reduceByKey(combOp, curNumPartitions) \ + .values() + + return partiallyAggregated.reduce(combOp) + def max(self, key=None): """ Find the maximum item in this RDD. @@ -1130,6 +1245,18 @@ def first(self): return rs[0] raise ValueError("RDD is empty") + def isEmpty(self): + """ + Returns true if and only if the RDD contains no elements at all. Note that an RDD + may be empty even when it has at least 1 partition. + + >>> sc.parallelize([]).isEmpty() + True + >>> sc.parallelize([1]).isEmpty() + False + """ + return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0 + def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file @@ -1255,10 +1382,14 @@ def saveAsPickleFile(self, path, batchSize=10): ser = BatchedSerializer(PickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) - def saveAsTextFile(self, path): + def saveAsTextFile(self, path, compressionCodecClass=None): """ Save this RDD as a text file, using string representations of elements. + @param path: path to text file + @param compressionCodecClass: (None by default) string i.e. 
+ "org.apache.hadoop.io.compress.GzipCodec" + >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) @@ -1274,6 +1405,16 @@ def saveAsTextFile(self, path): >>> sc.parallelize(['', 'foo', '', 'bar', '']).saveAsTextFile(tempFile2.name) >>> ''.join(sorted(input(glob(tempFile2.name + "/part-0000*")))) '\\n\\n\\nbar\\nfoo\\n' + + Using compressionCodecClass + + >>> tempFile3 = NamedTemporaryFile(delete=True) + >>> tempFile3.close() + >>> codec = "org.apache.hadoop.io.compress.GzipCodec" + >>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec) + >>> from fileinput import input, hook_compressed + >>> ''.join(sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed))) + 'bar\\nfoo\\n' """ def func(split, iterator): for x in iterator: @@ -1284,7 +1425,11 @@ def func(split, iterator): yield x keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True - keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) + if compressionCodecClass: + compressionCodec = self.ctx._jvm.java.lang.Class.forName(compressionCodecClass) + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path, compressionCodec) + else: + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) # Pair functions @@ -1459,6 +1604,9 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): """ if numPartitions is None: numPartitions = self._defaultReducePartitions() + partitioner = Partitioner(numPartitions, partitionFunc) + if self.partitioner == partitioner: + return self # Transferring O(n) objects to Java is too expensive. 
# Instead, we'll form the hash buckets in Python, @@ -1503,18 +1651,16 @@ def add_shuffle_key(split, iterator): yield pack_long(split) yield outputSerializer.dumps(items) - keyed = self.mapPartitionsWithIndex(add_shuffle_key) + keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True) keyed._bypass_serializer = True with SCCallSiteSync(self.context) as css: pairRDD = self.ctx._jvm.PairwiseRDD( keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, - id(partitionFunc)) - jrdd = pairRDD.partitionBy(partitioner).values() + jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions, + id(partitionFunc)) + jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner)) rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) - # This is required so that id(partitionFunc) remains unique, - # even if partitionFunc is a lambda: - rdd._partitionFunc = partitionFunc + rdd.partitioner = partitioner return rdd # TODO: add control over map-side aggregation @@ -1560,7 +1706,7 @@ def combineLocally(iterator): merger.mergeValues(iterator) return merger.iteritems() - locally_combined = self.mapPartitions(combineLocally) + locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): @@ -1569,7 +1715,7 @@ def _mergeCombiners(iterator): merger.mergeCombiners(iterator) return merger.iteritems() - return shuffled.mapPartitions(_mergeCombiners, True) + return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ @@ -1611,8 +1757,8 @@ def groupByKey(self, numPartitions=None): Hash-partitions the resulting RDD with into numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey will provide much - better performance. 
+ sum or average) over each key, using reduceByKey or aggregateByKey will + provide much better performance. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) @@ -1804,7 +1950,7 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: + if my_batch != other_batch or not my_batch: # use the smallest batchSize for both of them batchSize = min(my_batch, other_batch) if batchSize <= 0: @@ -1948,8 +2094,8 @@ def lookup(self, key): """ values = self.filter(lambda (k, v): k == key).values() - if self._partitionFunc is not None: - return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False) + if self.partitioner is not None: + return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False) return values.collect() @@ -1965,6 +2111,7 @@ def _to_java_object_rdd(self): def countApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate version of count() that returns a potentially incomplete result within a timeout, even if not all tasks have finished. @@ -1978,6 +2125,7 @@ def countApprox(self, timeout, confidence=0.95): def sumApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate operation to return the sum within a timeout or meet the confidence. @@ -1994,6 +2142,7 @@ def sumApprox(self, timeout, confidence=0.95): def meanApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate operation to return the mean within a timeout or meet the confidence. @@ -2010,6 +2159,7 @@ def meanApprox(self, timeout, confidence=0.95): def countApproxDistinct(self, relativeSD=0.05): """ .. note:: Experimental + Return approximate number of distinct elements in the RDD. 
The algorithm used is based on streamlib's implementation of @@ -2036,6 +2186,39 @@ def countApproxDistinct(self, relativeSD=0.05): hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF) return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD) + def toLocalIterator(self): + """ + Return an iterator that contains all of the elements in this RDD. + The iterator will consume as much memory as the largest partition in this RDD. + >>> rdd = sc.parallelize(range(10)) + >>> [x for x in rdd.toLocalIterator()] + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + """ + partitions = xrange(self.getNumPartitions()) + for partition in partitions: + rows = self.context.runJob(self, lambda x: x, [partition]) + for row in rows: + yield row + + +def _prepare_for_python_RDD(sc, command, obj=None): + # the serialized command will be compressed by broadcast + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) + if len(pickled_command) > (1 << 20): # 1M + broadcast = sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) + # tracking the life cycle by obj + if obj is not None: + obj._broadcast = broadcast + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in sc._pickled_broadcast_vars], + sc._gateway._gateway_client) + sc._pickled_broadcast_vars.clear() + env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) + includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client) + return pickled_command, broadcast_vars, env, includes + class PipelinedRDD(RDD): @@ -2081,7 +2264,7 @@ def pipeline_func(split, iterator): self._id = None self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False - self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None + self.partitioner = prev.partitioner if self.preservesPartitioning else None self._broadcast = None def __del__(self): @@ -2095,34 +2278,25 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: 
self._jrdd_deserializer = NoOpSerializer() - enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" - profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None - command = (self.func, profileStats, self._prev_jrdd_deserializer, + + if self.ctx.profiler_collector: + profiler = self.ctx.profiler_collector.new_profiler(self.ctx) + else: + profiler = None + + command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) - # the serialized command will be compressed by broadcast - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - self._broadcast = self.ctx.broadcast(pickled_command) - pickled_command = ser.dumps(self._broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx._gateway._gateway_client) - self.ctx._pickled_broadcast_vars.clear() - env = MapConverter().convert(self.ctx.environment, - self.ctx._gateway._gateway_client) - includes = ListConverter().convert(self.ctx._python_includes, - self.ctx._gateway._gateway_client) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_command), + bytearray(pickled_cmd), env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator) + bvars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() - if enable_profile: + if profiler: self._id = self._jrdd_val.id() - self.ctx._add_profile(self._id, profileStats) + self.ctx.profiler_collector.add_profiler(self._id, profiler) return self._jrdd_val def id(self): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b8bda835174b2..0ffb41d02f6f6 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -70,6 +70,7 @@ class SpecialLengths(object): 
PYTHON_EXCEPTION_THROWN = -2 TIMING_DATA = -3 END_OF_STREAM = -4 + NULL = -5 class Serializer(object): @@ -133,6 +134,8 @@ def load_stream(self, stream): def _write_with_length(self, obj, stream): serialized = self.dumps(obj) + if serialized is None: + raise ValueError("serialized value should not be None") if len(serialized) > (1 << 31): raise ValueError("can not serialize object larger than 2G") write_int(len(serialized), stream) @@ -145,8 +148,10 @@ def _read_with_length(self, stream): length = read_int(stream) if length == SpecialLengths.END_OF_DATA_SECTION: raise EOFError + elif length == SpecialLengths.NULL: + return None obj = stream.read(length) - if obj == "": + if len(obj) < length: raise EOFError return self.loads(obj) @@ -484,6 +489,8 @@ def loads(self, stream): length = read_int(stream) if length == SpecialLengths.END_OF_DATA_SECTION: raise EOFError + elif length == SpecialLengths.NULL: + return None s = stream.read(length) return s.decode("utf-8") if self.use_unicode else s diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 89cf76920e353..1a02fece9c5a5 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -31,13 +31,18 @@ import atexit import os import platform + +import py4j + import pyspark from pyspark.context import SparkContext +from pyspark.sql import SQLContext, HiveContext from pyspark.storagelevel import StorageLevel -# this is the equivalent of ADD_JARS -add_files = (os.environ.get("ADD_FILES").split(',') - if os.environ.get("ADD_FILES") is not None else None) +# this is the deprecated equivalent of ADD_JARS +add_files = None +if os.environ.get("ADD_FILES") is not None: + add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) @@ -45,6 +50,13 @@ sc = SparkContext(appName="PySparkShell", pyFiles=add_files) atexit.register(lambda: sc.stop()) +try: + # Try to access HiveConf, it will raise 
exception if Hive is not added + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + sqlCtx = HiveContext(sc) +except py4j.protocol.Py4JError: + sqlCtx = SQLContext(sc) + print("""Welcome to ____ __ / __/__ ___ _____/ /__ @@ -56,9 +68,10 @@ platform.python_version(), platform.python_build()[0], platform.python_build()[1])) -print("SparkContext available as sc.") +print("SparkContext available as sc, %s available as sqlCtx." % sqlCtx.__class__.__name__) if add_files is not None: + print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") print("Adding files: [%s]" % ", ".join(add_files)) # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py new file mode 100644 index 0000000000000..b9ffd6945ea7e --- /dev/null +++ b/python/pyspark/sql/__init__.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +public classes of Spark SQL: + + - L{SQLContext} + Main entry point for SQL functionality. + - L{DataFrame} + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, DataFrames also support SQL. 
+ - L{GroupedData} + - L{Column} + Column is a DataFrame with a single column. + - L{Row} + A Row of data returned by a Spark SQL query. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive. +""" + +from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.types import Row +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD + +__all__ = [ + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', +] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py new file mode 100644 index 0000000000000..5d7aeb664cadf --- /dev/null +++ b/python/pyspark/sql/context.py @@ -0,0 +1,744 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +import warnings +import json +from array import array +from itertools import imap + +from py4j.protocol import Py4JError +from py4j.java_collections import MapConverter + +from pyspark.rdd import RDD, _prepare_for_python_RDD +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql.types import StringType, StructType, _verify_type, \ + _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter +from pyspark.sql.dataframe import DataFrame + +try: + import pandas + has_pandas = True +except ImportError: + has_pandas = False + +__all__ = ["SQLContext", "HiveContext"] + + +def _monkey_patch_RDD(sqlCtx): + def toDF(self, schema=None, sampleRatio=None): + """ + Convert current :class:`RDD` into a :class:`DataFrame` + + This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)` + + :param schema: a StructType or list of names of columns + :param sampleRatio: the sample ratio of rows used for inferring + :return: a DataFrame + + >>> rdd.toDF().collect() + [Row(name=u'Alice', age=1)] + """ + return sqlCtx.createDataFrame(self, schema, sampleRatio) + + RDD.toDF = toDF + + +class SQLContext(object): + + """Main entry point for Spark SQL functionality. + + A SQLContext can be used to create L{DataFrame}, register L{DataFrame} as + tables, execute SQL over tables, cache tables, and read parquet files. + """ + + def __init__(self, sparkContext, sqlContext=None): + """Create a new SQLContext. + + It will add a method called `toDF` to :class:`RDD`, which could be + used to convert an RDD into a DataFrame, it's a shorthand for + :func:`SQLContext.createDataFrame`. + + :param sparkContext: The SparkContext to wrap. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new + SQLContext in the JVM, instead we make all calls to this object.
+ + >>> from datetime import datetime + >>> sqlCtx = SQLContext(sc) + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> df = allTypes.toDF() + >>> df.registerTempTable("allTypes") + >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + ... x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + """ + self._sc = sparkContext + self._jsc = self._sc._jsc + self._jvm = self._sc._jvm + self._scala_SQLContext = sqlContext + _monkey_patch_RDD(self) + + @property + def _ssql_ctx(self): + """Accessor for the JVM Spark SQL context. + + Subclasses can override this property to provide their own + JVM Contexts. + """ + if self._scala_SQLContext is None: + self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) + return self._scala_SQLContext + + def setConf(self, key, value): + """Sets the given Spark SQL configuration property. + """ + self._ssql_ctx.setConf(key, value) + + def getConf(self, key, defaultValue): + """Returns the value of Spark SQL configuration property for the given key. + + If the key is not set, returns defaultValue. + """ + return self._ssql_ctx.getConf(key, defaultValue) + + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it defaults to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type.
+ + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + + >>> from pyspark.sql.types import IntegerType + >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) + self._ssql_ctx.udf().registerPython(name, + bytearray(pickled_cmd), + env, + includes, + self._sc.pythonExec, + bvars, + self._sc._javaAccumulator, + returnType.json()) + + def _inferSchema(self, rdd, samplingRatio=None): + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.sql.Row instead") + + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio < 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + return schema + + def inferSchema(self, rdd, samplingRatio=None): + """Infer and apply a schema to an RDD of L{Row}. + + ::note: + Deprecated in 1.3, use :func:`createDataFrame` instead + + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. 
Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. + + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. + + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. + + >>> rdd = sc.parallelize( + ... [Row(field1=1, field2="row1"), + ... Row(field1=2, field2="row2"), + ... Row(field1=3, field2="row3")]) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] + Row(field1=1, field2=u'row1') + """ + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + schema = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(schema) + rdd = rdd.map(converter) + return self.applySchema(rdd, schema) + + def applySchema(self, rdd, schema): + """ + Applies the given schema to the given RDD of L{tuple} or L{list}. + + ::note: + Deprecated in 1.3, use :func:`createDataFrame` instead + + These tuples or lists can contain complex nested structures like + lists, maps or nested rows. + + The schema should be a StructType. + + It is important that the schema matches the types of the objects + in each row or exceptions could be thrown at runtime. + + >>> from pyspark.sql.types import * + >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) + >>> schema = StructType([StructField("field1", IntegerType(), False), + ... 
StructField("field2", StringType(), False)]) + >>> df = sqlCtx.applySchema(rdd2, schema) + >>> df.collect() + [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] + """ + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType, but got %s" % schema) + + # take the first few rows to verify schema + rows = rdd.take(10) + # Row() cannot be deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + + for row in rows: + _verify_type(row, schema) + + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) + + def createDataFrame(self, data, schema=None, samplingRatio=None): + """ + Create a DataFrame from an RDD of tuple/list, list or pandas.DataFrame. + + `schema` could be :class:`StructType` or a list of column names. + + When `schema` is a list of column names, the type of each column + will be inferred from `data`. + + When `schema` is None, it will try to infer the column name and type + from `data`, which should be an RDD of :class:`Row`, or namedtuple, + or dict. + + If inferring is needed, `samplingRatio` is used to determine how many + rows will be used for inferring. The first row will be used if + `samplingRatio` is None.
+ + :param data: an RDD of Row/tuple/list/dict, list, or pandas.DataFrame + :param schema: a StructType or list of names of columns + :param samplingRatio: the sample ratio of rows used for inferring + :return: a DataFrame + + >>> l = [('Alice', 1)] + >>> sqlCtx.createDataFrame(l).collect() + [Row(_1=u'Alice', _2=1)] + >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect() + [Row(name=u'Alice', age=1)] + + >>> d = [{'name': 'Alice', 'age': 1}] + >>> sqlCtx.createDataFrame(d).collect() + [Row(age=1, name=u'Alice')] + + >>> rdd = sc.parallelize(l) + >>> sqlCtx.createDataFrame(rdd).collect() + [Row(_1=u'Alice', _2=1)] + >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age']) + >>> df.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql import Row + >>> Person = Row('name', 'age') + >>> person = rdd.map(lambda r: Person(*r)) + >>> df2 = sqlCtx.createDataFrame(person) + >>> df2.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("name", StringType(), True), + ... StructField("age", IntegerType(), True)]) + >>> df3 = sqlCtx.createDataFrame(rdd, schema) + >>> df3.collect() + [Row(name=u'Alice', age=1)] + + >>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + [Row(name=u'Alice', age=1)] + """ + if isinstance(data, DataFrame): + raise TypeError("data is already a DataFrame") + + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = list(data.columns) + data = [r.tolist() for r in data.to_records(index=False)] + + if not isinstance(data, RDD): + try: + # data could be list, tuple, generator ... 
+ data = self._sc.parallelize(data) + except Exception: + raise ValueError("cannot create an RDD from type: %s" % type(data)) + + if schema is None: + return self.inferSchema(data, samplingRatio) + + if isinstance(schema, (list, tuple)): + first = data.first() + if not isinstance(first, (list, tuple)): + raise ValueError("each row in `rdd` should be list or tuple, " + "but got %r" % type(first)) + row_cls = Row(*schema) + schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio) + + return self.applySchema(data, schema) + + def registerDataFrameAsTable(self, rdd, tableName): + """Registers the given RDD as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of + SQLContext. + + >>> sqlCtx.registerDataFrameAsTable(df, "table1") + """ + if (rdd.__class__ is DataFrame): + df = rdd._jdf + self._ssql_ctx.registerDataFrameAsTable(df, tableName) + else: + raise ValueError("Can only register DataFrame as table") + + def parquetFile(self, *paths): + """Loads a Parquet file, returning the result as a L{DataFrame}. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + gateway = self._sc._gateway + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) + for i in range(0, len(paths)): + jpaths[i] = paths[i] + jdf = self._ssql_ctx.parquetFile(jpaths) + return DataFrame(jdf, self) + + def jsonFile(self, path, schema=None, samplingRatio=1.0): + """ + Loads a text file storing one JSON object per line as a + L{DataFrame}. + + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. 
+ + >>> import tempfile, shutil + >>> jsonFile = tempfile.mkdtemp() + >>> shutil.rmtree(jsonFile) + >>> with open(jsonFile, 'w') as f: + ... f.writelines(jsonStrings) + >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> df1.printSchema() + root + |-- field1: long (nullable = true) + |-- field2: string (nullable = true) + |-- field3: struct (nullable = true) + | |-- field4: long (nullable = true) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType()), + ... StructField("field3", + ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) + >>> df2 = sqlCtx.jsonFile(jsonFile, schema) + >>> df2.printSchema() + root + |-- field2: string (nullable = true) + |-- field3: struct (nullable = true) + | |-- field5: array (nullable = true) + | | |-- element: integer (containsNull = true) + """ + if schema is None: + df = self._ssql_ctx.jsonFile(path, samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) + + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): + """Loads an RDD storing one JSON object per string as a L{DataFrame}. + + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. + + >>> df1 = sqlCtx.jsonRDD(json) + >>> df1.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) + + >>> df2 = sqlCtx.jsonRDD(json, df1.schema) + >>> df2.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType()), + ... StructField("field3", + ... StructType([StructField("field5", ArrayType(IntegerType()))])) + ... 
]) + >>> df3 = sqlCtx.jsonRDD(json, schema) + >>> df3.first() + Row(field2=u'row1', field3=Row(field5=None)) + + """ + + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = rdd.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._jvm.BytesToString()) + if schema is None: + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) + + def load(self, path=None, source=None, schema=None, **options): + """Returns the dataset in a data source as a DataFrame. + + The data source is specified by the `source` and a set of `options`. + If `source` is not specified, the default data source configured by + spark.sql.sources.default will be used. + + Optionally, a schema can be provided as the schema of the returned DataFrame. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + if schema is None: + df = self._ssql_ctx.load(source, joptions) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.load(source, scala_datatype, joptions) + return DataFrame(df, self) + + def createExternalTable(self, tableName, path=None, source=None, + schema=None, **options): + """Creates an external table based on the dataset in a data source. + + It returns the DataFrame associated with the external table. + + The data source is specified by the `source` and a set of `options`. 
+ If `source` is not specified, the default data source configured by + spark.sql.sources.default will be used. + + Optionally, a schema can be provided as the schema of the returned DataFrame and + created external table. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + if schema is None: + df = self._ssql_ctx.createExternalTable(tableName, source, joptions) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, + joptions) + return DataFrame(df, self) + + def sql(self, sqlQuery): + """Return a L{DataFrame} representing the result of the given query. + + >>> sqlCtx.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + """ + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + + def table(self, tableName): + """Returns the specified table as a L{DataFrame}. + + >>> sqlCtx.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlCtx.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + return DataFrame(self._ssql_ctx.table(tableName), self) + + def tables(self, dbName=None): + """Returns a DataFrame containing names of tables in the given database. + + If `dbName` is not specified, the current database will be used. + + The returned DataFrame has two columns, tableName and isTemporary + (a column with BooleanType indicating if a table is a temporary one or not). 
+ + >>> sqlCtx.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlCtx.tables() + >>> df2.filter("tableName = 'table1'").first() + Row(tableName=u'table1', isTemporary=True) + """ + if dbName is None: + return DataFrame(self._ssql_ctx.tables(), self) + else: + return DataFrame(self._ssql_ctx.tables(dbName), self) + + def tableNames(self, dbName=None): + """Returns a list of names of tables in the database `dbName`. + + If `dbName` is not specified, the current database will be used. + + >>> sqlCtx.registerDataFrameAsTable(df, "table1") + >>> "table1" in sqlCtx.tableNames() + True + >>> "table1" in sqlCtx.tableNames("db") + True + """ + if dbName is None: + return [name for name in self._ssql_ctx.tableNames()] + else: + return [name for name in self._ssql_ctx.tableNames(dbName)] + + def cacheTable(self, tableName): + """Caches the specified table in-memory.""" + self._ssql_ctx.cacheTable(tableName) + + def uncacheTable(self, tableName): + """Removes the specified table from the in-memory cache.""" + self._ssql_ctx.uncacheTable(tableName) + + def clearCache(self): + """Removes all cached tables from the in-memory cache. """ + self._ssql_ctx.clearCache() + + +class HiveContext(SQLContext): + + """A variant of Spark SQL that integrates with data stored in Hive. + + Configuration for Hive is read from hive-site.xml on the classpath. + It supports running both SQL and HiveQL commands. + """ + + def __init__(self, sparkContext, hiveContext=None): + """Create a new HiveContext. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + HiveContext in the JVM, instead we make all calls to this object. 
+ """ + SQLContext.__init__(self, sparkContext) + + if hiveContext: + self._scala_HiveContext = hiveContext + + @property + def _ssql_ctx(self): + try: + if not hasattr(self, '_scala_HiveContext'): + self._scala_HiveContext = self._get_hive_ctx() + return self._scala_HiveContext + except Py4JError as e: + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " + "build/sbt assembly", e) + + def _get_hive_ctx(self): + return self._jvm.HiveContext(self._jsc.sc()) + + +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row + + +class Row(tuple): + + """ + A row in L{DataFrame}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + values = tuple(kwargs[n] for n in names) + row = tuple.__new__(self, values) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + def asDict(self): + """ + Return as an dict + """ + if not hasattr(self, "__FIELDS__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__FIELDS__, self)) + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise 
AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "" % ", ".join(self) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.context + globs = pyspark.sql.context.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['rdd'] = rdd = sc.parallelize( + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] + ) + _monkey_patch_RDD(sqlCtx) + globs['df'] = rdd.toDF() + jsonStrings = [ + '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' + ] + globs['jsonStrings'] = jsonStrings + globs['json'] = sc.parallelize(jsonStrings) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.context, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py new file mode 100644 index 0000000000000..aec99017fbdc1 --- /dev/null +++ b/python/pyspark/sql/dataframe.py @@ -0,0 +1,1058 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import itertools +import warnings +import random +import os +from tempfile import NamedTemporaryFile + +from py4j.java_collections import ListConverter, MapConverter + +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.sql.types import * +from pyspark.sql.types import _create_cls, _parse_datatype_json_string + + +__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"] + + +class DataFrame(object): + + """A collection of rows that have the same columns. + + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: + + people = sqlContext.parquetFile("...") + + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. + + To select a column from the data frame, use the apply method:: + + ageCol = people.age + + Note that the :class:`Column` type can also be manipulated + through its various functions:: + + # The following creates a new column that increases everybody's age by 10. 
+ people.age + 10 + + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + self._sc = sql_ctx and sql_ctx._sc + self.is_cached = False + self._schema = None # initialized lazily + + @property + def rdd(self): + """ + Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row` s. + """ + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema + + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) + + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd + + def toJSON(self, use_unicode=False): + """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row. + + >>> df.toJSON().first() + '{"age":2,"name":"Alice"}' + """ + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + + def saveAsParquetFile(self, path): + """Save the contents as a Parquet file, preserving the schema. + + Files that are written out using this method can be read back in as + a :class:`DataFrame` using the L{SQLContext.parquetFile} method. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) + True + """ + self._jdf.saveAsParquetFile(path) + + def registerTempTable(self, name): + """Registers this RDD as a temporary table using the given name. 
+ + The lifetime of this temporary table is tied to the L{SQLContext} + that was used to create this DataFrame. + + >>> df.registerTempTable("people") + >>> df2 = sqlCtx.sql("select * from people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + self._jdf.registerTempTable(name) + + def registerAsTable(self, name): + """DEPRECATED: use registerTempTable() instead""" + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) + + def insertInto(self, tableName, overwrite=False): + """Inserts the contents of this :class:`DataFrame` into the specified table. + + Optionally overwriting any existing data. + """ + self._jdf.insertInto(tableName, overwrite) + + def _java_save_mode(self, mode): + """Returns the Java save mode based on the Python save mode represented by a string. + """ + jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode + jmode = jSaveMode.ErrorIfExists + mode = mode.lower() + if mode == "append": + jmode = jSaveMode.Append + elif mode == "overwrite": + jmode = jSaveMode.Overwrite + elif mode == "ignore": + jmode = jSaveMode.Ignore + elif mode == "error": + pass + else: + raise ValueError( + "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.") + return jmode + + def saveAsTable(self, tableName, source=None, mode="append", **options): + """Saves the contents of the :class:`DataFrame` to a data source as a table. + + The data source is specified by the `source` and a set of `options`. + If `source` is not specified, the default data source configured by + spark.sql.sources.default will be used. + + Additionally, mode is used to specify the behavior of the saveAsTable operation when + table already exists in the data source. There are four modes: + + * append: Contents of this :class:`DataFrame` are expected to be appended \ + to existing table. 
+ * overwrite: Data in the existing table is expected to be overwritten by \ + the contents of this DataFrame. + * error: An exception is expected to be thrown. + * ignore: The save operation is expected to not save the contents of the \ + :class:`DataFrame` and to not change the existing table. + """ + if source is None: + source = self.sql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._java_save_mode(mode) + joptions = MapConverter().convert(options, + self.sql_ctx._sc._gateway._gateway_client) + self._jdf.saveAsTable(tableName, source, jmode, joptions) + + def save(self, path=None, source=None, mode="append", **options): + """Saves the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the `source` and a set of `options`. + If `source` is not specified, the default data source configured by + spark.sql.sources.default will be used. + + Additionally, mode is used to specify the behavior of the save operation when + data already exists in the data source. There are four modes: + + * append: Contents of this :class:`DataFrame` are expected to be appended to existing data. + * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. + * error: An exception is expected to be thrown. + * ignore: The save operation is expected to not save the contents of \ + the :class:`DataFrame` and to not change the existing data. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.sql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._java_save_mode(mode) + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + self._jdf.save(source, jmode, joptions) + + @property + def schema(self): + """Returns the schema of this :class:`DataFrame` (represented by + a L{StructType}). 
+ + >>> df.schema + StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + """ + if self._schema is None: + self._schema = _parse_datatype_json_string(self._jdf.schema().json()) + return self._schema + + def printSchema(self): + """Prints out the schema in the tree format. + + >>> df.printSchema() + root + |-- age: integer (nullable = true) + |-- name: string (nullable = true) + + """ + print (self._jdf.schema().treeString()) + + def explain(self, extended=False): + """ + Prints the plans (logical and physical) to the console for + debugging purpose. + + If extended is False, only prints the physical plan. + + >>> df.explain() + PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:... + + >>> df.explain(True) + == Parsed Logical Plan == + ... + == Analyzed Logical Plan == + ... + == Optimized Logical Plan == + ... + == Physical Plan == + ... + == RDD == + """ + if extended: + print self._jdf.queryExecution().toString() + else: + print self._jdf.queryExecution().executedPlan().toString() + + def isLocal(self): + """ + Returns True if the `collect` and `take` methods can be run locally + (without any Spark executors). + """ + return self._jdf.isLocal() + + def show(self, n=20): + """ + Print the first n rows. + + >>> df + DataFrame[age: int, name: string] + >>> df.show() + age name + 2 Alice + 5 Bob + """ + print self._jdf.showString(n).encode('utf8', 'ignore') + + def __repr__(self): + return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + + def count(self): + """Return the number of elements in this RDD. + + Unlike the base RDD implementation of count, this implementation + leverages the query optimizer to compute the count on the DataFrame, + which supports features such as filter pushdown. + + >>> df.count() + 2L + """ + return self._jdf.count() + + def collect(self): + """Return a list that contains all of the rows. 
+ + Each object in the list is a Row, the fields can be accessed as + attributes. + + >>> df.collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + with SCCallSiteSync(self._sc) as css: + bytesInJava = self._jdf.javaToPython().collect().iterator() + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) + tempFile.close() + self._sc._writeToFile(bytesInJava, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) + os.unlink(tempFile.name) + cls = _create_cls(self.schema) + return [cls(r) for r in rs] + + def limit(self, num): + """Limit the result count to the number specified. + + >>> df.limit(1).collect() + [Row(age=2, name=u'Alice')] + >>> df.limit(0).collect() + [] + """ + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) + + def take(self, num): + """Take the first num rows of the RDD. + + Each object in the list is a Row, the fields can be accessed as + attributes. + + >>> df.take(2) + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + return self.limit(num).collect() + + def map(self, f): + """ Return a new RDD by applying a function to each Row + + It's a shorthand for df.rdd.map() + + >>> df.map(lambda p: p.name).collect() + [u'Alice', u'Bob'] + """ + return self.rdd.map(f) + + def flatMap(self, f): + """ Return a new RDD by first applying a function to all elements of this, + and then flattening the results. + + It's a shorthand for df.rdd.flatMap() + + >>> df.flatMap(lambda p: p.name).collect() + [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] + """ + return self.rdd.flatMap(f) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition. 
+ + It's a shorthand for df.rdd.mapPartitions() + + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) + + def foreach(self, f): + """ + Applies a function to all rows of this DataFrame. + + It's a shorthand for df.rdd.foreach() + + >>> def f(person): + ... print person.name + >>> df.foreach(f) + """ + return self.rdd.foreach(f) + + def foreachPartition(self, f): + """ + Applies a function to each partition of this DataFrame. + + It's a shorthand for df.rdd.foreachPartition() + + >>> def f(people): + ... for person in people: + ... print person.name + >>> df.foreachPartition(f) + """ + return self.rdd.foreachPartition(f) + + def cache(self): + """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self._jdf.cache() + return self + + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """ Set the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) + return self + + def unpersist(self, blocking=True): + """ Mark it as non-persistent, and remove all blocks for it from + memory and disk. + """ + self.is_cached = False + self._jdf.unpersist(blocking) + return self + + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) + + def repartition(self, numPartitions): + """ Return a new :class:`DataFrame` that has exactly `numPartitions` + partitions. 
+ + >>> df.repartition(10).rdd.getNumPartitions() + 10 + """ + return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + + def distinct(self): + """ + Return a new :class:`DataFrame` containing the distinct rows in this DataFrame. + + >>> df.distinct().count() + 2L + """ + return DataFrame(self._jdf.distinct(), self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this DataFrame. + + >>> df.sample(False, 0.5, 97).count() + 1L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + @property + def dtypes(self): + """Return all column names and their data types as a list. + + >>> df.dtypes + [('age', 'int'), ('name', 'string')] + """ + return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] + + @property + def columns(self): + """ Return all column names as a list. + + >>> df.columns + [u'age', u'name'] + """ + return [f.name for f in self.schema.fields] + + def join(self, other, joinExprs=None, joinType=None): + """ + Join with another :class:`DataFrame`, using the given join expression. + The following performs a full outer join between `df1` and `df2`. + + :param other: Right side of the join + :param joinExprs: Join expression + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. 
+ + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + """ + + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + assert isinstance(joinExprs, Column), "joinExprs should be Column" + if joinType is None: + jdf = self._jdf.join(other._jdf, joinExprs._jc) + else: + assert isinstance(joinType, basestring), "joinType should be basestring" + jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + """ Return a new :class:`DataFrame` sorted by the specified column(s). + + :param cols: The columns or expressions used for sorting + + >>> df.sort(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.orderBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> from pyspark.sql.functions import * + >>> df.sort(asc("age")).collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.orderBy(desc("age"), "name").collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + """ + if not cols: + raise ValueError("should sort by at least one column") + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + orderBy = sort + + def head(self, n=None): + """ Return the first `n` rows or the first row if n is None. + + >>> df.head() + Row(age=2, name=u'Alice') + >>> df.head(1) + [Row(age=2, name=u'Alice')] + """ + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + def first(self): + """ Return the first row. 
+ + >>> df.first() + Row(age=2, name=u'Alice') + """ + return self.head() + + def __getitem__(self, item): + """ Return the column by given name + + >>> df.select(df['age']).collect() + [Row(age=2), Row(age=5)] + >>> df[ ["name", "age"]].collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df[ df.age > 3 ].collect() + [Row(age=5, name=u'Bob')] + """ + if isinstance(item, basestring): + jc = self._jdf.apply(item) + return Column(jc) + elif isinstance(item, Column): + return self.filter(item) + elif isinstance(item, list): + return self.select(*item) + else: + raise IndexError("unexpected index: %s" % item) + + def __getattr__(self, name): + """ Return the column by given name + + >>> df.select(df.age).collect() + [Row(age=2), Row(age=5)] + """ + if name.startswith("__"): + raise AttributeError(name) + jc = self._jdf.apply(name) + return Column(jc) + + def select(self, *cols): + """ Selecting a set of expressions. + + >>> df.select('*').collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('name', 'age').collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df.select(df.name, (df.age + 10).alias('age')).collect() + [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] + """ + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def selectExpr(self, *expr): + """ + Selects a set of SQL expressions. This is a variant of + `select` that accepts SQL expressions. 
+ + >>> df.selectExpr("age * 2", "abs(age)").collect() + [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] + """ + jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) + jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) + return DataFrame(jdf, self.sql_ctx) + + def filter(self, condition): + """ Filtering rows using the given condition, which could be + :class:`Column` expression or string of SQL expression. + + where() is an alias for filter(). + + >>> df.filter(df.age > 3).collect() + [Row(age=5, name=u'Bob')] + >>> df.where(df.age == 2).collect() + [Row(age=2, name=u'Alice')] + + >>> df.filter("age > 3").collect() + [Row(age=5, name=u'Bob')] + >>> df.where("age = 2").collect() + [Row(age=2, name=u'Alice')] + """ + if isinstance(condition, basestring): + jdf = self._jdf.filter(condition) + elif isinstance(condition, Column): + jdf = self._jdf.filter(condition._jc) + else: + raise TypeError("condition should be string or Column") + return DataFrame(jdf, self.sql_ctx) + + where = filter + + def groupBy(self, *cols): + """ Group the :class:`DataFrame` using the specified columns, + so we can run aggregation on them. See :class:`GroupedData` + for all the available aggregate functions. + + >>> df.groupBy().avg().collect() + [Row(AVG(age#0)=3.5)] + >>> df.groupBy('name').agg({'age': 'mean'}).collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + >>> df.groupBy(df.name).avg().collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + """ + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return GroupedData(jdf, self.sql_ctx) + + def agg(self, *exprs): + """ Aggregate on the entire :class:`DataFrame` without groups + (shorthand for df.groupBy.agg()). 
+ + >>> df.agg({"age": "max"}).collect() + [Row(MAX(age#0)=5)] + >>> from pyspark.sql import functions as F + >>> df.agg(F.min(df.age)).collect() + [Row(MIN(age#0)=2)] + """ + return self.groupBy().agg(*exprs) + + def unionAll(self, other): + """ Return a new :class:`DataFrame` containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + + def intersect(self, other): + """ Return a new :class:`DataFrame` containing rows only in + both this frame and another frame. + + This is equivalent to `INTERSECT` in SQL. + """ + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def subtract(self, other): + """ Return a new :class:`DataFrame` containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def withColumn(self, colName, col): + """ Return a new :class:`DataFrame` by adding a column. + + >>> df.withColumn('age2', df.age + 2).collect() + [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ + return self.select('*', col.alias(colName)) + + def withColumnRenamed(self, existing, new): + """ Rename an existing column to a new name + + >>> df.withColumnRenamed('age', 'age2').collect() + [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] + """ + cols = [Column(_to_java_column(c)).alias(new) + if c == existing else c + for c in self.columns] + return self.select(*cols) + + def toPandas(self): + """ + Collect all the rows and return a `pandas.DataFrame`. 
+ + >>> df.toPandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + import pandas as pd + return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """ + SchemaRDD is deprecated, please use DataFrame + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +def df_varargs_api(f): + def _api(self, *args): + jargs = ListConverter().convert(args, + self.sql_ctx._sc._gateway._gateway_client) + name = f.__name__ + jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs)) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by DataFrame.groupBy(). + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + """ Compute aggregates by specifying a map from column name + to aggregate methods. + + The available aggregate methods are `avg`, `max`, `min`, + `sum`, `count`. + + :param exprs: list or aggregate columns or a map from column + name to aggregate methods. 
+ + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"*": "count"}).collect() + [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)] + + >>> from pyspark.sql import functions as F + >>> gdf.agg(F.min(df.age)).collect() + [Row(MIN(age#0)=5), Row(MIN(age#0)=2)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jcols = ListConverter().convert([c._jc for c in exprs[1:]], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """ Count the number of rows for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @df_varargs_api + def mean(self, *cols): + """Compute the average value for each numeric columns + for each group. This is an alias for `avg`. + + >>> df.groupBy().mean('age').collect() + [Row(AVG(age#0)=3.5)] + >>> df3.groupBy().mean('age', 'height').collect() + [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)] + """ + + @df_varargs_api + def avg(self, *cols): + """Compute the average value for each numeric columns + for each group. + + >>> df.groupBy().avg('age').collect() + [Row(AVG(age#0)=3.5)] + >>> df3.groupBy().avg('age', 'height').collect() + [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)] + """ + + @df_varargs_api + def max(self, *cols): + """Compute the max value for each numeric columns for + each group. + + >>> df.groupBy().max('age').collect() + [Row(MAX(age#0)=5)] + >>> df3.groupBy().max('age', 'height').collect() + [Row(MAX(age#4L)=5, MAX(height#5L)=85)] + """ + + @df_varargs_api + def min(self, *cols): + """Compute the min value for each numeric column for + each group. 
+ + >>> df.groupBy().min('age').collect() + [Row(MIN(age#0)=2)] + >>> df3.groupBy().min('age', 'height').collect() + [Row(MIN(age#4L)=2, MIN(height#5L)=80)] + """ + + @df_varargs_api + def sum(self, *cols): + """Compute the sum for each numeric columns for each + group. + + >>> df.groupBy().sum('age').collect() + [Row(SUM(age#0)=7)] + >>> df3.groupBy().sum('age', 'height').collect() + [Row(SUM(age#4L)=7, SUM(height#5L)=165)] + """ + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.functions.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.functions.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc) + _.__doc__ = doc + return _ + + +def _func_op(name, doc=''): + def _(self): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +class Column(object): + + """ + A column in a DataFrame. + + :class:`Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + + df.colName + df["colName"] + + # 2. 
Create from an expression + df.colName + 1 + 1 / df.colName + """ + + def __init__(self, jc): + self._jc = jc + + # arithmetic operators + __neg__ = _func_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _func_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + def substr(self, startPos, length): + """ + Return a :class:`Column` which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.select(df.name.substr(1, 3).alias("col")).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc) + + __getslice__ = substr + + # order + asc = _unary_op("asc", "Returns a sort expression based on 
the" + " ascending order of the given column name.") + desc = _unary_op("desc", "Returns a sort expression based on the" + " descending order of the given column name.") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + def alias(self, alias): + """Return an alias for this column + + >>> df.select(df.age.alias("age2")).collect() + [Row(age2=2), Row(age2=5)] + """ + return Column(getattr(self._jc, "as")(alias)) + + def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + return Column(jc) + + def __repr__(self): + return 'Column<%s>' % self._jc.toString().encode('utf8') + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.dataframe + globs = pyspark.sql.dataframe.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), + Row(name='Bob', age=5, height=85)]).toDF() + (failure_count, test_count) = doctest.testmod( + pyspark.sql.dataframe, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
| doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py new file mode 100644 index 0000000000000..5873f09ae3275 --- /dev/null +++ b/python/pyspark/sql/functions.py @@ -0,0 +1,174 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" +A collection of builtin functions +""" + +from itertools import imap + +from py4j.java_collections import ListConverter + +from pyspark import SparkContext +from pyspark.rdd import _prepare_for_python_RDD +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql.types import StringType +from pyspark.sql.dataframe import Column, _to_java_column + + +__all__ = ['countDistinct', 'approxCountDistinct', 'udf'] + + +def _create_function(name, doc=""): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +_functions = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'asc': 'Returns a sort expression based on the ascending order of the given column name.', + 'desc': 'Returns a sort expression based on the descending order of the given column name.', + + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to lower case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolute value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the
values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', +} + + +for _name, _doc in _functions.items(): + globals()[_name] = _create_function(_name, _doc) +del _name, _doc +__all__ += _functions.keys() +__all__.sort() + + +def countDistinct(col, *cols): + """ Return a new Column for distinct count of `col` or `cols` + + >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() + [Row(c=2)] + + >>> df.agg(countDistinct("age", "name").alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) + jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + +def approxCountDistinct(col, rsd=None): + """ Return a new Column for approximate distinct count of `col` + + >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + +class UserDefinedFunction(object): + """ + User defined function in Python + """ + def __init__(self, func, returnType): + self.func = func + self.returnType = returnType + self._broadcast = None + self._judf = self._create_judf() + + def _create_judf(self): + f = self.func # put it in closure `func` + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + sc = SparkContext._active_spark_context + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(self.returnType.json()) + judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + includes, 
sc.pythonExec, broadcast_vars, + sc._javaAccumulator, jdt) + return judf + + def __del__(self): + if self._broadcast is not None: + self._broadcast.unpersist() + self._broadcast = None + + def __call__(self, *cols): + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + +def udf(f, returnType=StringType()): + """Create a user defined function (UDF) + + >>> from pyspark.sql.types import IntegerType + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> df.select(slen(df.name).alias('slen')).collect() + [Row(slen=5), Row(slen=3)] + """ + return UserDefinedFunction(f, returnType) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.functions + globs = pyspark.sql.functions.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = SQLContext(sc) + globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() + (failure_count, test_count) = doctest.testmod( + pyspark.sql.functions, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py new file mode 100644 index 0000000000000..83899ad4b1b12 --- /dev/null +++ b/python/pyspark/sql/tests.py @@ -0,0 +1,487 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for pyspark.sql; additional tests are implemented as doctests in +individual modules. +""" +import os +import sys +import pydoc +import shutil +import tempfile + +import py4j + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.sql import SQLContext, HiveContext, Column, Row +from pyspark.sql.types import * +from pyspark.sql.types import UserDefinedType, _infer_type +from pyspark.tests import ReusedPySparkTestCase + + +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. 
+ """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + +class SQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sqlCtx = SQLContext(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData) + cls.df = rdd.toDF() + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=range(3), d={"key": range(5)})] + rdd = self.sc.parallelize(d) + self.sqlCtx.createDataFrame(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(range(3), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + 
self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema + + # cache and checkpoint + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() + + def test_apply_schema_to_row(self): + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.createDataFrame(rdd) + row = df.head() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = df.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = df.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = df.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + + def test_infer_schema(self): + d = [Row(l=[], d={}, s=None), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.createDataFrame(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], 
df.map(lambda r: r.s).collect()) + df.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_infer_nested_schema(self): + NestedRow = Row("f1", "f2") + nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), + NestedRow([2, 3], {"row2": 2.0})]) + df = self.sqlCtx.inferSchema(nestedRdd1) + self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) + + nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), + NestedRow([[2, 3], [3, 4]], [2, 3])]) + df = self.sqlCtx.inferSchema(nestedRdd2) + self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) + + from collections import namedtuple + CustomRow = namedtuple('CustomRow', 'field1 field2') + rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), + CustomRow(field1=2, field2="row2"), + CustomRow(field1=3, field2="row3")]) + df = self.sqlCtx.inferSchema(rdd) + self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + + def test_apply_schema(self): + from datetime import date, datetime + rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3], None)]) + schema = StructType([ + StructField("byte1", ByteType(), False), + StructField("byte2", ByteType(), False), + StructField("short1", ShortType(), False), + StructField("short2", ShortType(), False), + StructField("int1", IntegerType(), False), + StructField("float1", FloatType(), False), + StructField("date1", DateType(), False), + 
StructField("time1", TimestampType(), False), + StructField("map1", MapType(StringType(), IntegerType(), False), False), + StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), + StructField("list1", ArrayType(ByteType(), False), False), + StructField("null1", DoubleType(), True)]) + df = self.sqlCtx.applySchema(rdd, schema) + results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, + x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) + r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), + datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + self.assertEqual(r, results.first()) + + df.registerTempTable("table2") + r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() + + self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) + + from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type + rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3])]) + abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" + schema = _parse_schema_abstract(abstract) + typedSchema = _infer_schema_type(rdd.first(), schema) + df = self.sqlCtx.applySchema(rdd, typedSchema) + r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) + self.assertEqual(r, tuple(df.first())) + + def test_struct_in_map(self): + d = [Row(m={Row(i=1): Row(s="")})] + df = self.sc.parallelize(d).toDF() + k, v = df.head().m.items()[0] + self.assertEqual(1, k.i) + self.assertEqual("", v.s) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + df = self.sc.parallelize([row]).toDF() + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() + self.assertEqual(1, 
row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_infer_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.sc.parallelize([row]).toDF() + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = rdd.toDF(schema) + point = df.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.sql.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df0 = self.sc.parallelize([row]).toDF() + output_dir = os.path.join(self.tempdir.name, "labeled_point") + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_column_operators(self): + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + 
self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import functions + self.assertEqual((0, u'99'), + tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) + self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + + def test_save_and_load(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.save(tmpPath, "org.apache.spark.sql.json", "error") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + + df.save(tmpPath, "org.apache.spark.sql.json", "overwrite") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath, + noUse="this options will not be used in save.") + actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath, + noUse="this options will not be used in load.") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + defaultDataSourceName = 
self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + + def test_infer_long_type(self): + longrow = [Row(f1='a', f2=100000000000000)] + df = self.sc.parallelize(longrow).toDF() + self.assertEqual(df.schema.fields[1].dataType, LongType()) + + # this saving as Parquet caused issues as well. + output_dir = os.path.join(self.tempdir.name, "infer_long_type") + df.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + self.assertEquals('a', df1.first().f1) + self.assertEquals(100000000000000, df1.first().f2) + + self.assertEqual(_infer_type(1), LongType()) + self.assertEqual(_infer_type(2**10), LongType()) + self.assertEqual(_infer_type(2**20), LongType()) + self.assertEqual(_infer_type(2**31 - 1), LongType()) + self.assertEqual(_infer_type(2**31), LongType()) + self.assertEqual(_infer_type(2**61), LongType()) + self.assertEqual(_infer_type(2**71), LongType()) + + +class HiveContextSQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.sqlCtx = None + return + os.unlink(cls.tempdir.name) + _scala_HiveContext =\ + cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) + cls.sqlCtx = 
HiveContext(cls.sc, _scala_HiveContext) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_save_and_load_table(self): + if self.sqlCtx is None: + return # no hive available, skipped + + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath) + actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, + "org.apache.spark.sql.json") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath) + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.createExternalTable("externalJsonTable", + source="org.apache.spark.sql.json", + schema=schema, path=tmpPath, + noUse="this options will not be used") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.select("value").collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + df.saveAsTable("savedJsonTable", 
path=tmpPath, mode="overwrite") + actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/sql.py b/python/pyspark/sql/types.py similarity index 50% rename from python/pyspark/sql.py rename to python/pyspark/sql/types.py index dcd3b60a6062b..0f5dc2be6dab8 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql/types.py @@ -15,21 +15,6 @@ # limitations under the License. # -""" -public classes of Spark SQL: - - - L{SQLContext} - Main entry point for SQL functionality. - - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. - - L{Row} - A Row of data returned by a Spark SQL query. - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. 
-""" - -import itertools import decimal import datetime import keyword @@ -38,23 +23,12 @@ import re from array import array from operator import itemgetter -from itertools import imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - -from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ - CloudPickleSerializer, UTF8Deserializer -from pyspark.storagelevel import StorageLevel -from pyspark.traceback_utils import SCCallSiteSync __all__ = [ - "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", - "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", - "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "SchemaRDD", "Row"] + "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", + "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", + "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] class DataType(object): @@ -78,6 +52,9 @@ def __ne__(self, other): def typeName(cls): return cls.__name__[:-4].lower() + def simpleString(self): + return self.typeName() + def jsonValue(self): return self.typeName() @@ -171,6 +148,12 @@ def __init__(self, precision=None, scale=None): self.scale = scale self.hasPrecisionInfo = precision is not None + def simpleString(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal(10,0)" + def jsonValue(self): if self.hasPrecisionInfo: return "decimal(%d,%d)" % (self.precision, self.scale) @@ -206,6 +189,8 @@ class ByteType(PrimitiveType): The data type representing int values with 1 singed byte. """ + def simpleString(self): + return 'tinyint' class IntegerType(PrimitiveType): @@ -214,6 +199,8 @@ class IntegerType(PrimitiveType): The data type representing int values. 
""" + def simpleString(self): + return 'int' class LongType(PrimitiveType): @@ -224,6 +211,8 @@ class LongType(PrimitiveType): beyond the range of [-9223372036854775808, 9223372036854775807], please use DecimalType. """ + def simpleString(self): + return 'bigint' class ShortType(PrimitiveType): @@ -232,6 +221,8 @@ class ShortType(PrimitiveType): The data type representing int values with 2 signed bytes. """ + def simpleString(self): + return 'smallint' class ArrayType(DataType): @@ -259,6 +250,9 @@ def __init__(self, elementType, containsNull=True): self.elementType = elementType self.containsNull = containsNull + def simpleString(self): + return 'array<%s>' % self.elementType.simpleString() + def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) @@ -309,6 +303,9 @@ def __init__(self, keyType, valueType, valueContainsNull=True): self.valueType = valueType self.valueContainsNull = valueContainsNull + def simpleString(self): + return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString()) + def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) @@ -363,6 +360,9 @@ def __init__(self, name, dataType, nullable=True, metadata=None): self.nullable = nullable self.metadata = metadata or {} + def simpleString(self): + return '%s:%s' % (self.name, self.dataType.simpleString()) + def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) @@ -405,6 +405,9 @@ def __init__(self, fields): """ self.fields = fields + def simpleString(self): + return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) + def __repr__(self): return ("StructType(List(%s))" % ",".join(str(field) for field in self.fields)) @@ -461,6 +464,9 @@ def deserialize(self, datum): """ raise NotImplementedError("UDT must implement deserialize().") + def simpleString(self): + return 'null' + def json(self): return 
json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) @@ -577,7 +583,7 @@ def _parse_datatype_json_value(json_value): _type_mappings = { type(None): NullType, bool: BooleanType, - int: IntegerType, + int: LongType, long: LongType, float: DoubleType, str: StringType, @@ -598,7 +604,7 @@ def _infer_type(obj): ExamplePointUDT """ if obj is None: - raise ValueError("Can not infer type for None") + return NullType() if hasattr(obj, '__UDT__'): return obj.__UDT__ @@ -631,15 +637,14 @@ def _infer_schema(row): if isinstance(row, dict): items = sorted(row.items()) - elif isinstance(row, tuple): + elif isinstance(row, (tuple, list)): if hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) elif hasattr(row, "__FIELDS__"): # Row items = zip(row.__FIELDS__, tuple(row)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in row): - items = row else: - raise ValueError("Can't infer schema from tuple") + names = ['_%d' % i for i in range(1, len(row) + 1)] + items = zip(names, row) elif hasattr(row, "__dict__"): # object items = sorted(row.__dict__.items()) @@ -806,17 +811,10 @@ def convert_struct(obj): if obj is None: return - if isinstance(obj, tuple): - if hasattr(obj, "_fields"): - d = dict(zip(obj._fields, obj)) - elif hasattr(obj, "__FIELDS__"): - d = dict(zip(obj.__FIELDS__, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - d = dict(obj) - else: - raise ValueError("unexpected tuple: %s" % str(obj)) + if isinstance(obj, (tuple, list)): + return tuple(conv(v) for v, conv in zip(obj, converters)) - elif isinstance(obj, dict): + if isinstance(obj, dict): d = obj elif hasattr(obj, "__dict__"): # object d = obj.__dict__ @@ -922,16 +920,16 @@ def _parse_schema_abstract(s): def _infer_schema_type(obj, dataType): """ - Fill the dataType with types infered from obj + Fill the dataType with types inferred from obj >>> schema = _parse_schema_abstract("a b c d") >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) >>> 
_infer_schema_type(row, schema) - StructType...IntegerType...DoubleType...StringType...DateType... + StructType...LongType...DoubleType...StringType...DateType... >>> row = [[1], {"key": (1, 2.0)}] >>> schema = _parse_schema_abstract("a[] b{c d}") >>> _infer_schema_type(row, schema) - StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... + StructType...a,ArrayType...b,MapType(StringType,...c,LongType... """ if dataType is None: return _infer_type(obj) @@ -986,7 +984,7 @@ def _verify_type(obj, dataType): >>> _verify_type(None, StructType([])) >>> _verify_type("", StringType()) - >>> _verify_type(0, IntegerType()) + >>> _verify_type(0, LongType()) >>> _verify_type(range(3), ArrayType(ShortType())) >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): @@ -1016,7 +1014,7 @@ def _verify_type(obj, dataType): return _type = type(dataType) - assert _type in _acceptable_types, "unkown datatype: %s" % dataType + assert _type in _acceptable_types, "unknown datatype: %s" % dataType # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: @@ -1034,7 +1032,7 @@ def _verify_type(obj, dataType): elif isinstance(dataType, StructType): if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with" + raise ValueError("Length of object (%d) does not match with " "length of fields (%d)" % (len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) @@ -1171,7 +1169,7 @@ def Dict(d): class Row(tuple): - """ Row in SchemaRDD """ + """ Row in DataFrame """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) __slots__ = () @@ -1194,534 +1192,6 @@ def __reduce__(self): return Row -class SQLContext(object): - - """Main entry point for Spark SQL functionality. 
- - A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as - tables, execute SQL over tables, cache tables, and read parquet files. - """ - - def __init__(self, sparkContext, sqlContext=None): - """Create a new SQLContext. - - :param sparkContext: The SparkContext to wrap. - :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new - SQLContext in the JVM, instead we make all calls to this object. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - - >>> bad_rdd = sc.parallelize([1,2,3]) - >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - - >>> from datetime import datetime - >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, - ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), - ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' - ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, - ... x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] - """ - self._sc = sparkContext - self._jsc = self._sc._jsc - self._jvm = self._sc._jvm - self._scala_SQLContext = sqlContext - - @property - def _ssql_ctx(self): - """Accessor for the JVM Spark SQL context. - - Subclasses can override this property to provide their own - JVM Contexts. 
- """ - if self._scala_SQLContext is None: - self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) - return self._scala_SQLContext - - def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. - - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] - """ - func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, None, - AutoBatchedSerializer(PickleSerializer()), - AutoBatchedSerializer(PickleSerializer())) - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - broadcast = self._sc.broadcast(pickled_command) - pickled_command = ser.dumps(broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self._sc._pickled_broadcast_vars], - self._sc._gateway._gateway_client) - self._sc._pickled_broadcast_vars.clear() - env = MapConverter().convert(self._sc.environment, - self._sc._gateway._gateway_client) - includes = ListConverter().convert(self._sc._python_includes, - self._sc._gateway._gateway_client) - self._ssql_ctx.udf().registerPython(name, - bytearray(pickled_command), - env, - includes, - self._sc.pythonExec, - broadcast_vars, - self._sc._javaAccumulator, - returnType.json()) - - def inferSchema(self, rdd, samplingRatio=None): - """Infer and apply a schema to an RDD of L{Row}. 
- - When samplingRatio is specified, the schema is inferred by looking - at the types of each row in the sampled dataset. Otherwise, the - first 100 rows of the RDD are inspected. Nested collections are - supported, which can include array, dict, list, Row, tuple, - namedtuple, or object. - - Each row could be L{pyspark.sql.Row} object or namedtuple or objects. - Using top level dicts is deprecated, as dict is used to represent Maps. - - If a single column has multiple distinct inferred types, it may cause - runtime exceptions. - - >>> rdd = sc.parallelize( - ... [Row(field1=1, field2="row1"), - ... Row(field1=2, field2="row2"), - ... Row(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] - Row(field1=1, field2=u'row1') - - >>> NestedRow = Row("f1", "f2") - >>> nestedRdd1 = sc.parallelize([ - ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), - ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() - [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] - - >>> nestedRdd2 = sc.parallelize([ - ... NestedRow([[1, 2], [2, 3]], [1, 2]), - ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() - [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] - - >>> from collections import namedtuple - >>> CustomRow = namedtuple('CustomRow', 'field1 field2') - >>> rdd = sc.parallelize( - ... [CustomRow(field1=1, field2="row1"), - ... CustomRow(field1=2, field2="row2"), - ... 
CustomRow(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] - Row(field1=1, field2=u'row1') - """ - - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") - - first = rdd.first() - if not first: - raise ValueError("The first row in RDD is empty, " - "can not infer schema") - if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") - - if samplingRatio is None: - schema = _infer_schema(first) - if _has_nulltype(schema): - for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) - if not _has_nulltype(schema): - break - else: - warnings.warn("Some of types cannot be determined by the " - "first 100 rows, please try again with sampling") - else: - if samplingRatio > 0.99: - rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) - - converter = _create_converter(schema) - rdd = rdd.map(converter) - return self.applySchema(rdd, schema) - - def applySchema(self, rdd, schema): - """ - Applies the given schema to the given RDD of L{tuple} or L{list}. - - These tuples or lists can contain complex nested structures like - lists, maps or nested rows. - - The schema should be a StructType. - - It is important that the schema matches the types of the objects - in each row or exceptions could be thrown at runtime. - - >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) - >>> schema = StructType([StructField("field1", IntegerType(), False), - ... StructField("field2", StringType(), False)]) - >>> srdd = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT * from table1") - >>> srdd2.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - - >>> from datetime import date, datetime - >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, - ... 
date(2010, 1, 1), - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3], None)]) - >>> schema = StructType([ - ... StructField("byte1", ByteType(), False), - ... StructField("byte2", ByteType(), False), - ... StructField("short1", ShortType(), False), - ... StructField("short2", ShortType(), False), - ... StructField("int", IntegerType(), False), - ... StructField("float", FloatType(), False), - ... StructField("date", DateType(), False), - ... StructField("time", TimestampType(), False), - ... StructField("map", - ... MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", - ... StructType([StructField("b", ShortType(), False)]), False), - ... StructField("list", ArrayType(ByteType(), False), False), - ... StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema) - >>> results = srdd.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, - ... x.time, x.map["a"], x.struct.b, x.list, x.null)) - >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE - (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), - datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - - >>> srdd.registerTempTable("table2") - >>> sqlCtx.sql( - ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.5 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] - - >>> rdd = sc.parallelize([(127, -32768, 1.0, - ... datetime(2010, 1, 1, 1, 1, 1), - ... 
{"a": 1}, (2,), [1, 2, 3])]) - >>> abstract = "byte short float time map{} struct(b) list[]" - >>> schema = _parse_schema_abstract(abstract) - >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> srdd = sqlCtx.applySchema(rdd, typedSchema) - >>> srdd.collect() - [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] - """ - - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") - - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) - rows = rdd.take(10) - - for row in rows: - _verify_type(row, schema) - - # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) - - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return SchemaRDD(srdd.toJavaSchemaRDD(), self) - - def registerRDDAsTable(self, rdd, tableName): - """Registers the given RDD as a temporary table in the catalog. - - Temporary tables exist only during the lifetime of this instance of - SQLContext. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - """ - if (rdd.__class__ is SchemaRDD): - srdd = rdd._jschema_rdd.baseSchemaRDD() - self._ssql_ctx.registerRDDAsTable(srdd, tableName) - else: - raise ValueError("Can only register SchemaRDD as table") - - def parquetFile(self, path): - """Loads a Parquet file, returning the result as a L{SchemaRDD}. 
- - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) - True - """ - jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() - return SchemaRDD(jschema_rdd, self) - - def jsonFile(self, path, schema=None, samplingRatio=1.0): - """ - Loads a text file storing one JSON object per line as a - L{SchemaRDD}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> ofn = open(jsonFile, 'w') - >>> for json in jsonStrings: - ... print>>ofn, json - >>> ofn.close() - >>> srdd1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... 
StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() - [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] - """ - if schema is None: - srdd = self._ssql_ctx.jsonFile(path, samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(srdd.toJavaSchemaRDD(), self) - - def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... 
StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() - [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] - - >>> sqlCtx.jsonRDD(sc.parallelize(['{}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - """ - - def func(iterator): - for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): - x = x.encode("utf-8") - yield x - keyed = rdd.mapPartitions(func) - keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(srdd.toJavaSchemaRDD(), self) - - def sql(self, sqlQuery): - """Return a L{SchemaRDD} representing the result of the given query. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] - """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self) - - def table(self, tableName): - """Returns the specified table as a L{SchemaRDD}. 
- - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.table("table1") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) - True - """ - return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self) - - def cacheTable(self, tableName): - """Caches the specified table in-memory.""" - self._ssql_ctx.cacheTable(tableName) - - def uncacheTable(self, tableName): - """Removes the specified table from the in-memory cache.""" - self._ssql_ctx.uncacheTable(tableName) - - -class HiveContext(SQLContext): - - """A variant of Spark SQL that integrates with data stored in Hive. - - Configuration for Hive is read from hive-site.xml on the classpath. - It supports running both SQL and HiveQL commands. - """ - - def __init__(self, sparkContext, hiveContext=None): - """Create a new HiveContext. - - :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new - HiveContext in the JVM, instead we make all calls to this object. - """ - SQLContext.__init__(self, sparkContext) - - if hiveContext: - self._scala_HiveContext = hiveContext - - @property - def _ssql_ctx(self): - try: - if not hasattr(self, '_scala_HiveContext'): - self._scala_HiveContext = self._get_hive_ctx() - return self._scala_HiveContext - except Py4JError as e: - raise Exception("You must build Spark with Hive. " - "Export 'SPARK_HIVE=true' and run " - "build/sbt assembly", e) - - def _get_hive_ctx(self): - return self._jvm.HiveContext(self._jsc.sc()) - - def hiveql(self, hqlQuery): - """ - DEPRECATED: Use sql() - """ - warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" + - "default. 
The SQL dialect for parsing can be set using 'spark.sql.dialect'", - DeprecationWarning) - return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self) - - def hql(self, hqlQuery): - """ - DEPRECATED: Use sql() - """ - warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by" + - "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", - DeprecationWarning) - return self.hiveql(hqlQuery) - - -class LocalHiveContext(HiveContext): - - def __init__(self, sparkContext, sqlContext=None): - HiveContext.__init__(self, sparkContext, sqlContext) - warnings.warn("LocalHiveContext is deprecated. " - "Use HiveContext instead.", DeprecationWarning) - - def _get_hive_ctx(self): - return self._jvm.LocalHiveContext(self._jsc.sc()) - - -class TestHiveContext(HiveContext): - - def _get_hive_ctx(self): - return self._jvm.TestHiveContext(self._jsc.sc()) - - def _create_row(fields, values): row = Row(*values) row.__FIELDS__ = fields @@ -1731,7 +1201,7 @@ def _create_row(fields, values): class Row(tuple): """ - A row in L{SchemaRDD}. The fields in it can be accessed like attributes. + A row in L{DataFrame}. The fields in it can be accessed like attributes. Row can be used to create a row object by using named arguments, the fields will be sorted by names. @@ -1809,350 +1279,21 @@ def __repr__(self): return "" % ", ".join(self) -def inherit_doc(cls): - for name, func in vars(cls).items(): - # only inherit docstring for public functions - if name.startswith("_"): - continue - if not func.__doc__: - for parent in cls.__bases__: - parent_func = getattr(parent, name, None) - if parent_func and getattr(parent_func, "__doc__", None): - func.__doc__ = parent_func.__doc__ - break - return cls - - -@inherit_doc -class SchemaRDD(RDD): - - """An RDD of L{Row} objects that has an associated schema. - - The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can - utilize the relational query api exposed by Spark SQL. 
- - For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the - L{SchemaRDD} is not operated on directly, as it's underlying - implementation is an RDD composed of Java objects. Instead it is - converted to a PythonRDD in the JVM, on which Python operations can - be done. - - This class receives raw tuples from Java but assigns a class to it in - all its data-collection methods (mapPartitionsWithIndex, collect, take, - etc) so that PySpark sees them as Row objects with named fields. - """ - - def __init__(self, jschema_rdd, sql_ctx): - self.sql_ctx = sql_ctx - self._sc = sql_ctx._sc - clsName = jschema_rdd.getClass().getName() - assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD" - self._jschema_rdd = jschema_rdd - self._id = None - self.is_cached = False - self.is_checkpointed = False - self.ctx = self.sql_ctx._sc - # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) - - @property - def _jrdd(self): - """Lazy evaluation of PythonRDD object. - - Only done when a user calls methods defined by the - L{pyspark.rdd.RDD} super class (map, filter, etc.). - """ - if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() - return self._lazy_jrdd - - def id(self): - if self._id is None: - self._id = self._jrdd.id() - return self._id - - def limit(self, num): - """Limit the result count to the number specified. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.limit(2).collect() - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> srdd.limit(0).collect() - [] - """ - rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD() - return SchemaRDD(rdd, self.sql_ctx) - - def toJSON(self, use_unicode=False): - """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. 
- - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( "SELECT * from table1") - >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' - True - >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] - True - """ - rdd = self._jschema_rdd.baseSchemaRDD().toJSON() - return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - - def saveAsParquetFile(self, path): - """Save the contents as a Parquet file, preserving the schema. - - Files that are written out using this method can be read back in as - a SchemaRDD using the L{SQLContext.parquetFile} method. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd2.collect()) == sorted(srdd.collect()) - True - """ - self._jschema_rdd.saveAsParquetFile(path) - - def registerTempTable(self, name): - """Registers this RDD as a temporary table using the given name. - - The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this SchemaRDD. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerTempTable("test") - >>> srdd2 = sqlCtx.sql("select * from test") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) - True - """ - self._jschema_rdd.registerTempTable(name) - - def registerAsTable(self, name): - """DEPRECATED: use registerTempTable() instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) - self.registerTempTable(name) - - def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this SchemaRDD into the specified table. - - Optionally overwriting any existing data. 
- """ - self._jschema_rdd.insertInto(tableName, overwrite) - - def saveAsTable(self, tableName): - """Creates a new table with the contents of this SchemaRDD.""" - self._jschema_rdd.saveAsTable(tableName) - - def schema(self): - """Returns the schema of this SchemaRDD (represented by - a L{StructType}).""" - return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) - - def schemaString(self): - """Returns the output schema in the tree format.""" - return self._jschema_rdd.schemaString() - - def printSchema(self): - """Prints out the schema in the tree format.""" - print self.schemaString() - - def count(self): - """Return the number of elements in this RDD. - - Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.count() - 3L - >>> srdd.count() == srdd.map(lambda x: x).count() - True - """ - return self._jschema_rdd.count() - - def collect(self): - """Return a list that contains all of the rows in this RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - Unlike the base RDD implementation of collect, this implementation - leverages the query optimizer to perform a collect on the SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() - [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] - """ - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator() - cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) - - def take(self, num): - """Take the first num rows of the RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. 
- - Unlike the base RDD implementation of take, this implementation - leverages the query optimizer to perform a collect on a SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.take(2) - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - """ - return self.limit(num).collect() - - # Convert each object in the RDD to a Row with the right class - # for this SchemaRDD, so that fields can be accessed as attributes. - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition of this RDD, - while tracking the index of the original partition. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(splitIndex, iterator): yield splitIndex - >>> rdd.mapPartitionsWithIndex(f).sum() - 6 - """ - rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) - - schema = self.schema() - - def applySchema(_, it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) - return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) - - # We override the default cache/persist/checkpoint behavior - # as we want to cache the underlying SchemaRDD object in the JVM, - # not the PythonRDD checkpointed by the super class - def cache(self): - self.is_cached = True - self._jschema_rdd.cache() - return self - - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): - self.is_cached = True - javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) - self._jschema_rdd.persist(javaStorageLevel) - return self - - def unpersist(self, blocking=True): - self.is_cached = False - self._jschema_rdd.unpersist(blocking) - return self - - def checkpoint(self): - self.is_checkpointed = True - self._jschema_rdd.checkpoint() - - def isCheckpointed(self): - return self._jschema_rdd.isCheckpointed() - - def getCheckpointFile(self): - checkpointFile = 
self._jschema_rdd.getCheckpointFile() - if checkpointFile.isPresent(): - return checkpointFile.get() - - def coalesce(self, numPartitions, shuffle=False): - rdd = self._jschema_rdd.coalesce(numPartitions, shuffle) - return SchemaRDD(rdd, self.sql_ctx) - - def distinct(self, numPartitions=None): - if numPartitions is None: - rdd = self._jschema_rdd.distinct() - else: - rdd = self._jschema_rdd.distinct(numPartitions) - return SchemaRDD(rdd, self.sql_ctx) - - def intersection(self, other): - if (other.__class__ is SchemaRDD): - rdd = self._jschema_rdd.intersection(other._jschema_rdd) - return SchemaRDD(rdd, self.sql_ctx) - else: - raise ValueError("Can only intersect with another SchemaRDD") - - def repartition(self, numPartitions): - rdd = self._jschema_rdd.repartition(numPartitions) - return SchemaRDD(rdd, self.sql_ctx) - - def subtract(self, other, numPartitions=None): - if (other.__class__ is SchemaRDD): - if numPartitions is None: - rdd = self._jschema_rdd.subtract(other._jschema_rdd) - else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) - return SchemaRDD(rdd, self.sql_ctx) - else: - raise ValueError("Can only subtract another SchemaRDD") - - def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this SchemaRDD. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.sample(False, 0.5, 97).count() - 2L - """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed)) - return SchemaRDD(rdd, self.sql_ctx) - - def takeSample(self, withReplacement, num, seed=None): - """Return a fixed-size sampled subset of this SchemaRDD. 
- - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.takeSample(False, 2, 97) - [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] - """ - seed = seed if seed is not None else random.randint(0, sys.maxint) - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD() \ - .takeSampleToPython(withReplacement, num, long(seed)) \ - .iterator() - cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) - - def _test(): import doctest from pyspark.context import SparkContext - # let doctest run in pyspark.sql, so DataTypes can be picklable - import pyspark.sql + # let doctest run in pyspark.sql.types, so DataTypes can be picklable + import pyspark.sql.types from pyspark.sql import Row, SQLContext - from pyspark.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.__dict__.copy() + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + globs = pyspark.sql.types.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) - globs['rdd'] = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")] - ) + globs['sqlCtx'] = sqlCtx = SQLContext(sc) globs['ExamplePoint'] = ExamplePoint globs['ExamplePointUDT'] = ExamplePointUDT - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings - globs['json'] = sc.parallelize(jsonStrings) (failure_count, test_count) = doctest.testmod( - pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) + pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/python/pyspark/status.py b/python/pyspark/status.py new file mode 100644 index 
0000000000000..a6fa7dd3144d4 --- /dev/null +++ b/python/pyspark/status.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import namedtuple + +__all__ = ["SparkJobInfo", "SparkStageInfo", "StatusTracker"] + + +class SparkJobInfo(namedtuple("SparkJobInfo", "jobId stageIds status")): + """ + Exposes information about Spark Jobs. + """ + + +class SparkStageInfo(namedtuple("SparkStageInfo", + "stageId currentAttemptId name numTasks numActiveTasks " + "numCompletedTasks numFailedTasks")): + """ + Exposes information about Spark Stages. + """ + + +class StatusTracker(object): + """ + Low-level status reporting APIs for monitoring job and stage progress. + + These APIs intentionally provide very weak consistency semantics; + consumers of these APIs should be prepared to handle empty / missing + information. For example, a job's stage ids may be known but the status + API may not have any information about the details of those stages, so + `getStageInfo` could potentially return `None` for a valid stage id. + + To limit memory usage, these APIs only provide information on recent + jobs / stages. 
These APIs will provide information for the last + `spark.ui.retainedStages` stages and `spark.ui.retainedJobs` jobs. + """ + def __init__(self, jtracker): + self._jtracker = jtracker + + def getJobIdsForGroup(self, jobGroup=None): + """ + Return a list of all known jobs in a particular job group. If + `jobGroup` is None, then returns all known jobs that are not + associated with a job group. + + The returned list may contain running, failed, and completed jobs, + and may vary across invocations of this method. This method does + not guarantee the order of the elements in its result. + """ + return list(self._jtracker.getJobIdsForGroup(jobGroup)) + + def getActiveStageIds(self): + """ + Returns an array containing the ids of all active stages. + """ + return sorted(list(self._jtracker.getActiveStageIds())) + + def getActiveJobsIds(self): + """ + Returns an array containing the ids of all active jobs. + """ + return sorted((list(self._jtracker.getActiveJobIds()))) + + def getJobInfo(self, jobId): + """ + Returns a :class:`SparkJobInfo` object, or None if the job info + could not be found or was garbage collected. + """ + job = self._jtracker.getJobInfo(jobId) + if job is not None: + return SparkJobInfo(jobId, job.stageIds(), str(job.status())) + + def getStageInfo(self, stageId): + """ + Returns a :class:`SparkStageInfo` object, or None if the stage + info could not be found or was garbage collected. 
+ """ + stage = self._jtracker.getStageInfo(stageId) + if stage is not None: + # TODO: fetch them in batch for better performance + attrs = [getattr(stage, f)() for f in SparkStageInfo._fields[1:]] + return SparkStageInfo(stageId, *attrs) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index d48f3598e33b2..2c73083c9f9a8 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -21,7 +21,7 @@ from py4j.java_gateway import java_import, JavaObject from pyspark import RDD, SparkConf -from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer +from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer from pyspark.context import SparkContext from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream @@ -189,7 +189,16 @@ def awaitTermination(self, timeout=None): if timeout is None: self._jssc.awaitTermination() else: - self._jssc.awaitTermination(int(timeout * 1000)) + self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) + + def awaitTerminationOrTimeout(self, timeout): + """ + Wait for the execution to stop. Return `true` if it's stopped; or + throw the reported error during the execution; or `false` if the + waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds + """ + self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) def stop(self, stopSparkContext=True, stopGraceFully=False): """ @@ -251,6 +260,20 @@ def textFileStream(self, directory): """ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + def binaryRecordsStream(self, directory, recordLength): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as flat binary files with records of + fixed length. Files must be written to the monitored directory by "moving" + them from another location within the same file system. 
+ File names starting with . are ignored. + + @param directory: Directory to load data from + @param recordLength: Length of each record in bytes + """ + return DStream(self._jssc.binaryRecordsStream(directory, recordLength), self, + NoOpSerializer()) + def _check_serializers(self, rdds): # make sure they have same serializer if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 2fe39392ff081..3fa42444239f7 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -578,7 +578,7 @@ def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) else: - g = a.cogroup(b, numPartitions) + g = a.cogroup(b.partitionBy(numPartitions), numPartitions) g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) return state.filter(lambda (k, v): v is not None) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py new file mode 100644 index 0000000000000..19ad71f99d4d5 --- /dev/null +++ b/python/pyspark/streaming/kafka.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_collections import MapConverter +from py4j.java_gateway import java_import, Py4JError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import PairDeserializer, NoOpSerializer +from pyspark.streaming import DStream + +__all__ = ['KafkaUtils', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') + + +class KafkaUtils(object): + + @staticmethod + def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + Create an input stream that pulls messages from a Kafka Broker. + + :param ssc: StreamingContext object + :param zkQuorum: Zookeeper quorum (hostname:port,hostname:port,..). + :param groupId: The group id for this consumer. + :param topics: Dict of (topic_name -> numPartitions) to consume. + Each partition is consumed in its own thread. + :param kafkaParams: Additional params for Kafka + :param storageLevel: RDD storage level. 
+ :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) + :return: A DStream object + """ + java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils") + + kafkaParams.update({ + "zookeeper.connect": zkQuorum, + "group.id": groupId, + "zookeeper.connection.timeout.ms": "10000", + }) + if not isinstance(topics, dict): + raise TypeError("topics should be dict") + jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) + jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client) + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + def getClassByName(name): + return ssc._jvm.org.apache.spark.util.Utils.classForName(name) + + try: + array = getClassByName("[B") + decoder = getClassByName("kafka.serializer.DefaultDecoder") + jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder, + jparam, jtopics, jlevel) + except Py4JError, e: + # TODO: use --jar once it also work on driver + if not e.message or 'call a package' in e.message: + print "No kafka package, please put the assembly jar into classpath:" + print " $ bin/spark-submit --driver-class-path external/kafka-assembly/target/" + \ + "scala-*/spark-streaming-kafka-assembly-*.jar" + raise e + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v))) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a8d876d0fa3b3..608f8e26473a6 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -21,6 +21,7 @@ import operator import unittest import tempfile +import struct from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext @@ -455,6 +456,20 @@ def test_text_file_stream(self): self.wait_for(result, 
2) self.assertEqual([range(10), range(10)], result) + def test_binary_records_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream = self.ssc.binaryRecordsStream(d, 10).map( + lambda v: struct.unpack("10b", str(v))) + result = self._collect(dstream, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "wb") as f: + f.write(bytearray(range(10))) + self.wait_for(result, 2) + self.assertEqual([range(10), range(10)], map(lambda v: list(v[0]), result)) + def test_union(self): input = [range(i + 1) for i in range(3)] dstream = self.ssc.queueStream(input) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b474fcf5bfb7e..06ba2b461d53e 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -46,13 +46,13 @@ from pyspark.conf import SparkConf from pyspark.context import SparkContext +from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer + CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType from pyspark import shuffle +from pyspark.profiler import BasicProfiler _have_scipy = False _have_numpy = False @@ -543,6 +543,12 @@ def test_zip_with_different_serializers(self): # regression test for bug in _reserializer() self.assertEqual(cnt, t.zip(rdd).count()) + def test_zip_with_different_object_sizes(self): + # regress test for SPARK-5973 + a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i) + self.assertEqual(10000, a.zip(b).count()) + def 
test_zip_with_different_number_of_items(self): a = self.sc.parallelize(range(5), 2) # different number of partitions @@ -714,6 +720,68 @@ def test_sample(self): wr_s21 = rdd.sample(True, 0.4, 21).collect() self.assertNotEqual(set(wr_s11), set(wr_s21)) + def test_null_in_rdd(self): + jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) + rdd = RDD(jrdd, self.sc, UTF8Deserializer()) + self.assertEqual([u"a", None, u"b"], rdd.collect()) + rdd = RDD(jrdd, self.sc, NoOpSerializer()) + self.assertEqual(["a", None, "b"], rdd.collect()) + + def test_multiple_python_java_RDD_conversions(self): + # Regression test for SPARK-5361 + data = [ + (u'1', {u'director': u'David Lean'}), + (u'2', {u'director': u'Andrew Dominik'}) + ] + data_rdd = self.sc.parallelize(data) + data_java_rdd = data_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # conversion between python and java RDD threw exceptions + data_java_rdd = converted_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + def test_narrow_dependency_in_join(self): + rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) + parted = rdd.partitionBy(2) + self.assertEqual(2, parted.union(parted).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) + + self.sc.setJobGroup("test1", "test", True) + tracker = self.sc.statusTracker() + + d = sorted(parted.join(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test1")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test2", "test", True) + d = 
sorted(parted.join(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test2")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test3", "test", True) + d = sorted(parted.cogroup(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], map(list, d[0][1])) + jobId = tracker.getJobIdsForGroup("test3")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test4", "test", True) + d = sorted(parted.cogroup(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], map(list, d[0][1])) + jobId = tracker.getJobIdsForGroup("test4")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + class ProfilerTests(PySparkTestCase): @@ -724,16 +792,12 @@ def setUp(self): self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): + self.do_computation() - def heavy_foo(x): - for i in range(1 << 20): - x = 1 - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - profiles = self.sc._profile_stats - self.assertEqual(1, len(profiles)) - id, acc, _ = profiles[0] - stats = acc.value + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, profiler, _ = profilers[0] + stats = profiler.stats() self.assertTrue(stats is not None) width, stat_list = stats.get_print_list([]) func_names = [func_name for fname, n, func_name in stat_list] @@ -744,235 +808,30 @@ def heavy_foo(x): self.sc.dump_profiles(d) self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + def test_custom_profiler(self): + class TestCustomProfiler(BasicProfiler): + def show(self, id): + self.result = "Custom formatting" -class ExamplePointUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. 
- """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return 'pyspark.tests' - - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return ExamplePoint(datum[0], datum[1]) - - -class ExamplePoint: - """ - An example class to demonstrate UDT in Scala, Java, and Python. - """ - - __UDT__ = ExamplePointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y + self.sc.profiler_collector.profiler_cls = TestCustomProfiler - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) + self.do_computation() - def __str__(self): - return "(%s,%s)" % (self.x, self.y) + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) - def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ - other.x == self.x and other.y == self.y - - -class SQLTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def setUp(self): - self.sqlCtx = SQLContext(self.sc) - - def test_udf(self): - self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") - [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def 
test_udf_with_array_type(self): - d = [Row(l=range(3), d={"key": range(5)})] - rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") - self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) - self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(range(3), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - srdd = self.sqlCtx.jsonRDD(rdd) - srdd.count() - srdd.collect() - srdd.schemaString() - srdd.schema() - - # cache and checkpoint - self.assertFalse(srdd.is_cached) - srdd.persist() - srdd.unpersist() - srdd.cache() - self.assertTrue(srdd.is_cached) - self.assertFalse(srdd.isCheckpointed()) - self.assertEqual(None, srdd.getCheckpointFile()) - - srdd = srdd.coalesce(2, True) - srdd = srdd.repartition(3) - srdd = srdd.distinct() - srdd.intersection(srdd) - self.assertEqual(2, srdd.count()) - - srdd.registerTempTable("temp") - srdd = self.sqlCtx.sql("select foo from temp") - srdd.count() - srdd.collect() + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) - def test_distinct(self): - rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10) - srdd = self.sqlCtx.jsonRDD(rdd) - self.assertEquals(srdd.getNumPartitions(), 10) - self.assertEquals(srdd.distinct().count(), 3) - result = srdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + def 
do_computation(self): + def heavy_foo(x): + for i in range(1 << 20): + x = 1 - def test_apply_schema_to_row(self): - srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) - self.assertEqual(srdd.collect(), srdd2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) - self.assertEqual(10, srdd3.count()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - row = srdd.first() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = srdd.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = srdd.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = srdd.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], srdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) - srdd.registerTempTable("test") - result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) - - srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(srdd.schema(), srdd2.schema()) - self.assertEqual({}, srdd2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) - srdd2.registerTempTable("test2") - result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): 
Row(s="")})] - rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - k, v = srdd.first().m.items()[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - srdd.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").first() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_infer_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - schema = srdd.schema() - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - srdd.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - srdd = self.sqlCtx.applySchema(rdd, schema) - point = srdd.first().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_parquet_with_udt(self): - from pyspark.tests import ExamplePoint - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - srdd0 = self.sqlCtx.inferSchema(rdd) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - srdd0.saveAsParquetFile(output_dir) - srdd1 = self.sqlCtx.parquetFile(output_dir) - point = srdd1.first().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + rdd = 
self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) class InputFormatTests(ReusedPySparkTestCase): @@ -1587,31 +1446,59 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.programDir) - def createTempFile(self, name, content): + def createTempFile(self, name, content, dir=None): """ Create a temp file with the given name and content and return its path. Strips leading spaces from content up to the first '|' in each line. """ pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) - path = os.path.join(self.programDir, name) + if dir is None: + path = os.path.join(self.programDir, name) + else: + os.makedirs(os.path.join(self.programDir, dir)) + path = os.path.join(self.programDir, dir, name) with open(path, "w") as f: f.write(content) return path - def createFileInZip(self, name, content): + def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): """ Create a zip archive containing a file with the given content and return its path. Strips leading spaces from content up to the first '|' in each line. 
""" pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) - path = os.path.join(self.programDir, name + ".zip") + if dir is None: + path = os.path.join(self.programDir, name + ext) + else: + path = os.path.join(self.programDir, dir, zip_name + ext) zip = zipfile.ZipFile(path, 'w') zip.writestr(name, content) zip.close() return path + def create_spark_package(self, artifact_name): + group_id, artifact_id, version = artifact_name.split(":") + self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" + | + | + | 4.0.0 + | %s + | %s + | %s + | + """ % (group_id, artifact_id, version)).lstrip(), + os.path.join(group_id, artifact_id, version)) + self.createFileInZip("%s.py" % artifact_id, """ + |def myfunc(x): + | return x + 1 + """, ".jar", os.path.join(group_id, artifact_id, version), + "%s-%s" % (artifact_id, version)) + def test_single_script(self): """Submit and test a single script file""" script = self.createTempFile("test.py", """ @@ -1680,6 +1567,39 @@ def test_module_dependency_on_cluster(self): self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out) + def test_package_dependency(self): + """Submit and test a script with a dependency on a Spark Package""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out) + + def test_package_dependency_on_cluster(self): + """Submit and test a script with a dependency on a Spark Package on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = 
SparkContext() + |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", + "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out) + def test_single_script_on_cluster(self): """Submit and test a single script on a cluster""" script = self.createTempFile("test.py", """ @@ -1733,6 +1653,37 @@ def test_with_stop(self): sc.stop() self.assertEqual(SparkContext._active_spark_context, None) + def test_progress_api(self): + with SparkContext() as sc: + sc.setJobGroup('test_progress_api', '', True) + + rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) + t = threading.Thread(target=rdd.collect) + t.daemon = True + t.start() + # wait for scheduler to start + time.sleep(1) + + tracker = sc.statusTracker() + jobIds = tracker.getJobIdsForGroup('test_progress_api') + self.assertEqual(1, len(jobIds)) + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual(1, len(job.stageIds)) + stage = tracker.getStageInfo(job.stageIds[0]) + self.assertEqual(rdd.getNumPartitions(), stage.numTasks) + + sc.cancelAllJobs() + t.join() + # wait for event listener to update the status + time.sleep(1) + + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual('FAILED', job.status) + self.assertEqual([], tracker.getActiveJobsIds()) + self.assertEqual([], tracker.getActiveStageIds()) + + sc.stop() + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7e5343c973dc5..8a93c320ec5d3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,8 +23,6 @@ import time import socket import traceback -import cProfile -import pstats from pyspark.accumulators import _accumulatorRegistry from 
pyspark.broadcast import Broadcast, _broadcastRegistry @@ -90,19 +88,15 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, stats, deserializer, serializer) = command + (func, profiler, deserializer, serializer) = command init_time = time.time() def process(): iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) - if stats: - p = cProfile.Profile() - p.runcall(process) - st = pstats.Stats(p) - st.stream = None # make it picklable - stats.add(st.strip_dirs()) + if profiler: + profiler.profile(process) else: process() except Exception: diff --git a/python/run-tests b/python/run-tests index 9ee19ed6e6b26..a2c2f37a54eda 100755 --- a/python/run-tests +++ b/python/run-tests @@ -35,7 +35,7 @@ rm -rf metastore warehouse function run_test() { echo "Running test: $1" | tee -a $LOG_FILE - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -57,13 +57,18 @@ function run_core_tests() { PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" + run_test "pyspark/profiler.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } function run_sql_tests() { echo "Run sql tests ..." 
- run_test "pyspark/sql.py" + run_test "pyspark/sql/types.py" + run_test "pyspark/sql/context.py" + run_test "pyspark/sql/dataframe.py" + run_test "pyspark/sql/functions.py" + run_test "pyspark/sql/tests.py" } function run_mllib_tests() { @@ -75,12 +80,19 @@ function run_mllib_tests() { run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat.py" + run_test "pyspark/mllib/stat/_statistics.py" run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" run_test "pyspark/mllib/tests.py" } +function run_ml_tests() { + echo "Run ml tests ..." + run_test "pyspark/ml/feature.py" + run_test "pyspark/ml/classification.py" + run_test "pyspark/ml/tests.py" +} + function run_streaming_tests() { echo "Run streaming tests ..." run_test "pyspark/streaming/util.py" @@ -102,6 +114,7 @@ $PYSPARK_PYTHON --version run_core_tests run_sql_tests run_mllib_tests +run_ml_tests run_streaming_tests # Try to test with PyPy diff --git a/repl/pom.xml b/repl/pom.xml index 0bc8bccf90a6d..b883344bf0ceb 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -33,8 +33,6 @@ repl - /usr/share/spark - root scala-2.10/src/main/scala scala-2.10/src/test/scala @@ -66,7 +64,6 @@ org.apache.spark spark-sql_${scala.binary.version} ${project.version} - test org.scala-lang @@ -87,18 +84,35 @@ scalacheck_${scala.binary.version} test + + + + org.eclipse.jetty + jetty-server + + + org.eclipse.jetty + jetty-plus + + + org.eclipse.jetty + jetty-util + + + org.eclipse.jetty + jetty-http + + + + + org.scala-lang + scala-library + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - org.apache.maven.plugins - maven-deploy-plugin - - true - - org.codehaus.mojo diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala index 05816941b54b3..6480e2d24e044 100644 --- 
a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala @@ -19,14 +19,21 @@ package org.apache.spark.repl import scala.tools.nsc.{Settings, CompilerCommand} import scala.Predef._ +import org.apache.spark.annotation.DeveloperApi /** * Command class enabling Spark-specific command line options (provided by * org.apache.spark.repl.SparkRunnerSettings). + * + * @example new SparkCommandLine(Nil).settings + * + * @param args The list of command line arguments + * @param settings The underlying settings to associate with this set of + * command-line options */ +@DeveloperApi class SparkCommandLine(args: List[String], override val settings: Settings) extends CompilerCommand(args, settings) { - def this(args: List[String], error: String => Unit) { this(args, new SparkRunnerSettings(error)) } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala index f8432c8af6ed2..5fb378112ef92 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -15,7 +15,7 @@ import scala.tools.nsc.ast.parser.Tokens.EOF import org.apache.spark.Logging -trait SparkExprTyper extends Logging { +private[repl] trait SparkExprTyper extends Logging { val repl: SparkIMain import repl._ diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala index 5340951d91331..955be17a73b85 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala @@ -17,6 +17,23 @@ package scala.tools.nsc +import org.apache.spark.annotation.DeveloperApi + +// NOTE: Forced to be public (and in scala.tools.nsc package) 
to access the +// settings "explicitParentLoader" method + +/** + * Provides exposure for the explicitParentLoader method on settings instances. + */ +@DeveloperApi object SparkHelper { + /** + * Retrieves the explicit parent loader for the provided settings. + * + * @param settings The settings whose explicit parent loader to retrieve + * + * @return The Optional classloader representing the explicit parent loader + */ + @DeveloperApi def explicitParentLoader(settings: Settings) = settings.explicitParentLoader } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e56b74edba88c..8dc0e0c965923 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -10,6 +10,8 @@ package org.apache.spark.repl import java.net.URL +import org.apache.spark.annotation.DeveloperApi + import scala.reflect.io.AbstractFile import scala.tools.nsc._ import scala.tools.nsc.backend.JavaPlatform @@ -43,6 +45,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} import org.apache.spark.Logging import org.apache.spark.SparkConf import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** The Scala interactive shell. 
It provides a read-eval-print loop @@ -57,20 +60,22 @@ import org.apache.spark.util.Utils * @author Lex Spoon * @version 1.2 */ -class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, - val master: Option[String]) - extends AnyRef - with LoopCommands - with SparkILoopInit - with Logging -{ +@DeveloperApi +class SparkILoop( + private val in0: Option[BufferedReader], + protected val out: JPrintWriter, + val master: Option[String] +) extends AnyRef with LoopCommands with SparkILoopInit with Logging { def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master)) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None) def this() = this(None, new JPrintWriter(Console.out, true), None) - var in: InteractiveReader = _ // the input stream from which commands come - var settings: Settings = _ - var intp: SparkIMain = _ + private var in: InteractiveReader = _ // the input stream from which commands come + + // NOTE: Exposed in package for testing + private[repl] var settings: Settings = _ + + private[repl] var intp: SparkIMain = _ @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i @@ -123,52 +128,55 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } + // NOTE: Must be public for visibility + @DeveloperApi var sparkContext: SparkContext = _ + var sqlContext: SQLContext = _ override def echoCommandMessage(msg: String) { intp.reporter printMessage msg } // def isAsync = !settings.Yreplsync.value - def isAsync = false + private[repl] def isAsync = false // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) - def history = in.history + private def history = in.history /** The context class loader at the time this object was created */ protected val originalClassLoader = Utils.getContextOrSparkClassLoader // 
classpath entries added via :cp - var addedClasspath: String = "" + private var addedClasspath: String = "" /** A reverse list of commands to replay if the user requests a :replay */ - var replayCommandStack: List[String] = Nil + private var replayCommandStack: List[String] = Nil /** A list of commands to replay if the user requests a :replay */ - def replayCommands = replayCommandStack.reverse + private def replayCommands = replayCommandStack.reverse /** Record a command for replay should the user request a :replay */ - def addReplay(cmd: String) = replayCommandStack ::= cmd + private def addReplay(cmd: String) = replayCommandStack ::= cmd - def savingReplayStack[T](body: => T): T = { + private def savingReplayStack[T](body: => T): T = { val saved = replayCommandStack try body finally replayCommandStack = saved } - def savingReader[T](body: => T): T = { + private def savingReader[T](body: => T): T = { val saved = in try body finally in = saved } - def sparkCleanUp(){ + private def sparkCleanUp(){ echo("Stopping spark context.") intp.beQuietDuring { command("sc.stop()") } } /** Close the interpreter and set the var to null. */ - def closeInterpreter() { + private def closeInterpreter() { if (intp ne null) { sparkCleanUp() intp.close() @@ -179,14 +187,16 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, class SparkILoopInterpreter extends SparkIMain(settings, out) { outer => - override lazy val formatting = new Formatting { + override private[repl] lazy val formatting = new Formatting { def prompt = SparkILoop.this.prompt } override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader) } - /** Create a new interpreter. */ - def createInterpreter() { + /** + * Constructs a new interpreter. 
+ */ + protected def createInterpreter() { require(settings != null) if (addedClasspath != "") settings.classpath.append(addedClasspath) @@ -207,7 +217,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** print a friendly help message */ - def helpCommand(line: String): Result = { + private def helpCommand(line: String): Result = { if (line == "") helpSummary() else uniqueCommand(line) match { case Some(lc) => echo("\n" + lc.longHelp) @@ -258,7 +268,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** Show the history */ - lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { + private lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { override def usage = "[num]" def defaultLines = 20 @@ -279,21 +289,21 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, // When you know you are most likely breaking into the middle // of a line being typed. This softens the blow. - protected def echoAndRefresh(msg: String) = { + private[repl] def echoAndRefresh(msg: String) = { echo("\n" + msg) in.redrawLine() } - protected def echo(msg: String) = { + private[repl] def echo(msg: String) = { out println msg out.flush() } - protected def echoNoNL(msg: String) = { + private def echoNoNL(msg: String) = { out print msg out.flush() } /** Search the history */ - def searchHistory(_cmdline: String) { + private def searchHistory(_cmdline: String) { val cmdline = _cmdline.toLowerCase val offset = history.index - history.size + 1 @@ -302,14 +312,27 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } private var currentPrompt = Properties.shellPromptString + + /** + * Sets the prompt string used by the REPL. 
+ * + * @param prompt The new prompt string + */ + @DeveloperApi def setPrompt(prompt: String) = currentPrompt = prompt - /** Prompt to print when awaiting input */ + + /** + * Represents the current prompt string used by the REPL. + * + * @return The current prompt string + */ + @DeveloperApi def prompt = currentPrompt import LoopCommand.{ cmd, nullary } /** Standard commands */ - lazy val standardCommands = List( + private lazy val standardCommands = List( cmd("cp", "", "add a jar or directory to the classpath", addClasspath), cmd("help", "[command]", "print this summary or command-specific help", helpCommand), historyCommand, @@ -333,7 +356,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, ) /** Power user commands */ - lazy val powerCommands: List[LoopCommand] = List( + private lazy val powerCommands: List[LoopCommand] = List( // cmd("phase", "", "set the implicit phase for power commands", phaseCommand) ) @@ -459,7 +482,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - protected def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { + private def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { override def tryClass(path: String): Array[Byte] = { val hd :: rest = path split '.' toList; // If there are dots in the name, the first segment is the @@ -581,7 +604,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, // } // } - /** Available commands */ + /** + * Provides a list of available commands. + * + * @return The list of commands + */ + @DeveloperApi def commands: List[LoopCommand] = standardCommands /*++ ( if (isReplPower) powerCommands else Nil )*/ @@ -613,7 +641,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, * command() for each line of input, and stops when * command() returns false. 
*/ - def loop() { + private def loop() { def readOneLine() = { out.flush() in readLine prompt @@ -642,7 +670,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** interpret all lines from a specified file */ - def interpretAllFrom(file: File) { + private def interpretAllFrom(file: File) { savingReader { savingReplayStack { file applyReader { reader => @@ -655,7 +683,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** create a new interpreter and replay the given commands */ - def replay() { + private def replay() { reset() if (replayCommandStack.isEmpty) echo("Nothing to replay.") @@ -665,7 +693,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, echo("") } } - def resetCommand() { + private def resetCommand() { echo("Resetting repl state.") if (replayCommandStack.nonEmpty) { echo("Forgetting this session history:\n") @@ -681,13 +709,13 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, reset() } - def reset() { + private def reset() { intp.reset() // unleashAndSetPhase() } /** fork a shell and run a command */ - lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { + private lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { override def usage = "" def apply(line: String): Result = line match { case "" => showUsage() @@ -698,14 +726,14 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - def withFile(filename: String)(action: File => Unit) { + private def withFile(filename: String)(action: File => Unit) { val f = File(filename) if (f.exists) action(f) else echo("That file does not exist") } - def loadCommand(arg: String) = { + private def loadCommand(arg: String) = { var shouldReplay: Option[String] = None withFile(arg)(f => { interpretAllFrom(f) @@ -714,7 +742,7 @@ class SparkILoop(in0: 
Option[BufferedReader], protected val out: JPrintWriter, Result(true, shouldReplay) } - def addAllClasspath(args: Seq[String]): Unit = { + private def addAllClasspath(args: Seq[String]): Unit = { var added = false var totalClasspath = "" for (arg <- args) { @@ -729,7 +757,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - def addClasspath(arg: String): Unit = { + private def addClasspath(arg: String): Unit = { val f = File(arg).normalize if (f.exists) { addedClasspath = ClassPath.join(addedClasspath, f.path) @@ -741,12 +769,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } - def powerCmd(): Result = { + private def powerCmd(): Result = { if (isReplPower) "Already in power mode." else enablePowerMode(false) } - def enablePowerMode(isDuringInit: Boolean) = { + private[repl] def enablePowerMode(isDuringInit: Boolean) = { // replProps.power setValue true // unleashAndSetPhase() // asyncEcho(isDuringInit, power.banner) @@ -759,12 +787,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, // } // } - def asyncEcho(async: Boolean, msg: => String) { + private def asyncEcho(async: Boolean, msg: => String) { if (async) asyncMessage(msg) else echo(msg) } - def verbosity() = { + private def verbosity() = { // val old = intp.printResults // intp.printResults = !old // echo("Switched " + (if (old) "off" else "on") + " result printing.") @@ -773,7 +801,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, /** Run one command submitted by the user. Two values are returned: * (1) whether to keep running, (2) the line to record for replay, * if any. 
*/ - def command(line: String): Result = { + private[repl] def command(line: String): Result = { if (line startsWith ":") { val cmd = line.tail takeWhile (x => !x.isWhitespace) uniqueCommand(cmd) match { @@ -789,7 +817,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) } - def pasteCommand(): Result = { + private def pasteCommand(): Result = { echo("// Entering paste mode (ctrl-D to finish)\n") val code = readWhile(_ => true) mkString "\n" echo("\n// Exiting paste mode, now interpreting.\n") @@ -820,7 +848,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, * read, go ahead and interpret it. Return the full string * to be recorded for replay, if any. */ - def interpretStartingWith(code: String): Option[String] = { + private def interpretStartingWith(code: String): Option[String] = { // signal completion non-completion input has been received in.completion.resetVerbosity() @@ -874,7 +902,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } // runs :load `file` on any files passed via -i - def loadFiles(settings: Settings) = settings match { + private def loadFiles(settings: Settings) = settings match { case settings: SparkRunnerSettings => for (filename <- settings.loadfiles.value) { val cmd = ":load " + filename @@ -889,7 +917,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, * unless settings or properties are such that it should start * with SimpleReader. 
*/ - def chooseReader(settings: Settings): InteractiveReader = { + private def chooseReader(settings: Settings): InteractiveReader = { if (settings.Xnojline.value || Properties.isEmacsShell) SimpleReader() else try new SparkJLineReader( @@ -903,8 +931,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - val m = u.runtimeMirror(Utils.getSparkClassLoader) + private val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe + private val m = u.runtimeMirror(Utils.getSparkClassLoader) private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = u.TypeTag[T]( m, @@ -913,7 +941,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] }) - def process(settings: Settings): Boolean = savingContextLoader { + private def process(settings: Settings): Boolean = savingContextLoader { if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") this.settings = settings @@ -972,6 +1000,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, true } + // NOTE: Must be public for visibility + @DeveloperApi def createSparkContext(): SparkContext = { val execUri = System.getenv("SPARK_EXECUTOR_URI") val jars = SparkILoop.getAddedJars @@ -979,7 +1009,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, .setMaster(getMaster()) .setAppName("Spark shell") .setJars(jars) - .set("spark.repl.class.uri", intp.classServer.uri) + .set("spark.repl.class.uri", intp.classServerUri) if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -988,6 +1018,23 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, sparkContext } + @DeveloperApi + def createSQLContext(): SQLContext = { + val name = "org.apache.spark.sql.hive.HiveContext" + val loader = 
Utils.getContextOrSparkClassLoader + try { + sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) + .newInstance(sparkContext).asInstanceOf[SQLContext] + logInfo("Created sql context (with Hive support)..") + } + catch { + case cnf: java.lang.ClassNotFoundException => + sqlContext = new SQLContext(sparkContext) + logInfo("Created sql context..") + } + sqlContext + } + private def getMaster(): String = { val master = this.master match { case Some(m) => m @@ -1014,18 +1061,19 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } @deprecated("Use `process` instead", "2.9.0") - def main(settings: Settings): Unit = process(settings) + private def main(settings: Settings): Unit = process(settings) } -object SparkILoop { +object SparkILoop extends Logging { implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp private def echo(msg: String) = Console println msg def getAddedJars: Array[String] = { val envJars = sys.env.get("ADD_JARS") - val propJars = sys.props.get("spark.jars").flatMap { p => - if (p == "") None else Some(p) + if (envJars.isDefined) { + logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") } + val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } val jars = propJars.orElse(envJars).getOrElse("") Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) } @@ -1033,7 +1081,7 @@ object SparkILoop { // Designed primarily for use by test code: take a String with a // bunch of code, and prints out a transcript of what it would look // like if you'd just typed it into the repl. 
- def runForTranscript(code: String, settings: Settings): String = { + private[repl] def runForTranscript(code: String, settings: Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } stringFromStream { ostream => @@ -1071,7 +1119,7 @@ object SparkILoop { /** Creates an interpreter loop with default settings and feeds * the given code to it as input. */ - def run(code: String, sets: Settings = new Settings): String = { + private[repl] def run(code: String, sets: Settings = new Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } stringFromStream { ostream => @@ -1087,5 +1135,5 @@ object SparkILoop { } } } - def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) + private[repl] def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index da4286c5e4874..05faef8786d2c 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -19,7 +19,7 @@ import org.apache.spark.SPARK_VERSION /** * Machinery for the asynchronous initialization of the repl. 
*/ -trait SparkILoopInit { +private[repl] trait SparkILoopInit { self: SparkILoop => /** Print a welcome message */ @@ -127,7 +127,17 @@ trait SparkILoopInit { _sc } """) + command(""" + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } + """) command("import org.apache.spark.SparkContext._") + command("import sqlContext.implicits._") + command("import sqlContext.sql") + command("import org.apache.spark.sql.functions._") } } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index b646f0b6f0868..35fb625645022 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -39,6 +39,7 @@ import scala.util.control.ControlThrowable import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} import org.apache.spark.util.Utils +import org.apache.spark.annotation.DeveloperApi // /** directory to save .class files to */ // private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) { @@ -84,17 +85,18 @@ import org.apache.spark.util.Utils * @author Moez A. 
Abdel-Gawad * @author Lex Spoon */ + @DeveloperApi class SparkIMain( initialSettings: Settings, val out: JPrintWriter, propagateExceptions: Boolean = false) extends SparkImports with Logging { imain => - val conf = new SparkConf() + private val conf = new SparkConf() - val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ - lazy val outputDir = { + private lazy val outputDir = { val tmp = System.getProperty("java.io.tmpdir") val rootDir = conf.get("spark.repl.classdir", tmp) Utils.createTempDir(rootDir) @@ -103,13 +105,20 @@ import org.apache.spark.util.Utils echo("Output directory: " + outputDir) } - val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles + /** + * Returns the path to the output directory containing all generated + * class files that will be served by the REPL class server. + */ + @DeveloperApi + lazy val getClassOutputDirectory = outputDir + + private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - val classServerPort = conf.getInt("spark.replClassServer.port", 0) - val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") + private val classServerPort = conf.getInt("spark.replClassServer.port", 0) + private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings - var printResults = true // whether to print result lines - var totalSilence = false // whether to print anything + private var printResults = true // whether to print result lines + private var totalSilence = false // whether to print anything private var _initializeComplete = false // compiler is initialized private var _isInitialized: 
Future[Boolean] = null // set up initialization future private var bindExceptions = true // whether to bind the lastException variable @@ -123,6 +132,14 @@ import org.apache.spark.util.Utils echo("Class server started, URI = " + classServer.uri) } + /** + * URI of the class server used to feed REPL compiled classes. + * + * @return The string representing the class server uri + */ + @DeveloperApi + def classServerUri = classServer.uri + /** We're going to go to some trouble to initialize the compiler asynchronously. * It's critical that nothing call into it until it's been initialized or we will * run into unrecoverable issues, but the perceived repl startup time goes @@ -141,17 +158,18 @@ import org.apache.spark.util.Utils () => { counter += 1 ; counter } } - def compilerClasspath: Seq[URL] = ( + private def compilerClasspath: Seq[URL] = ( if (isInitializeComplete) global.classPath.asURLs else new PathResolver(settings).result.asURLs // the compiler's classpath ) - def settings = currentSettings - def mostRecentLine = prevRequestList match { + // NOTE: Exposed to repl package since accessed indirectly from SparkIMain + private[repl] def settings = currentSettings + private def mostRecentLine = prevRequestList match { case Nil => "" case req :: _ => req.originalLine } // Run the code body with the given boolean settings flipped to true. 
- def withoutWarnings[T](body: => T): T = beQuietDuring { + private def withoutWarnings[T](body: => T): T = beQuietDuring { val saved = settings.nowarn.value if (!saved) settings.nowarn.value = true @@ -164,16 +182,28 @@ import org.apache.spark.util.Utils def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) def this() = this(new Settings()) - lazy val repllog: Logger = new Logger { + private lazy val repllog: Logger = new Logger { val out: JPrintWriter = imain.out val isInfo: Boolean = BooleanProp keyExists "scala.repl.info" val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug" val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace" } - lazy val formatting: Formatting = new Formatting { + private[repl] lazy val formatting: Formatting = new Formatting { val prompt = Properties.shellPromptString } - lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) + + // NOTE: Exposed to repl package since used by SparkExprTyper and SparkILoop + private[repl] lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) + + /** + * Determines if errors were reported (typically during compilation). + * + * @note This is not for runtime errors + * + * @return True if had errors, otherwise false + */ + @DeveloperApi + def isReportingErrors = reporter.hasErrors import formatting._ import reporter.{ printMessage, withoutTruncating } @@ -193,7 +223,8 @@ import org.apache.spark.util.Utils private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" // argument is a thunk to execute after init is done - def initialize(postInitSignal: => Unit) { + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def initialize(postInitSignal: => Unit) { synchronized { if (_isInitialized == null) { _isInitialized = io.spawn { @@ -203,15 +234,27 @@ import org.apache.spark.util.Utils } } } + + /** + * Initializes the underlying compiler/interpreter in a blocking fashion. 
+ * + * @note Must be executed before using SparkIMain! + */ + @DeveloperApi def initializeSynchronous(): Unit = { if (!isInitializeComplete) { _initialize() assert(global != null, global) } } - def isInitializeComplete = _initializeComplete + private def isInitializeComplete = _initializeComplete /** the public, go through the future compiler */ + + /** + * The underlying compiler used to generate ASTs and execute code. + */ + @DeveloperApi lazy val global: Global = { if (isInitializeComplete) _compiler else { @@ -226,13 +269,13 @@ import org.apache.spark.util.Utils } } @deprecated("Use `global` for access to the compiler instance.", "2.9.0") - lazy val compiler: global.type = global + private lazy val compiler: global.type = global import global._ import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember} import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass} - implicit class ReplTypeOps(tp: Type) { + private implicit class ReplTypeOps(tp: Type) { def orElse(other: => Type): Type = if (tp ne NoType) tp else other def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) } @@ -240,7 +283,8 @@ import org.apache.spark.util.Utils // TODO: If we try to make naming a lazy val, we run into big time // scalac unhappiness with what look like cycles. It has not been easy to // reduce, but name resolution clearly takes different paths. - object naming extends { + // NOTE: Exposed to repl package since used by SparkExprTyper + private[repl] object naming extends { val global: imain.global.type = imain.global } with Naming { // make sure we don't overwrite their unwisely named res3 etc. 
@@ -254,22 +298,43 @@ import org.apache.spark.util.Utils } import naming._ - object deconstruct extends { + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] object deconstruct extends { val global: imain.global.type = imain.global } with StructuredTypeStrings - lazy val memberHandlers = new { + // NOTE: Exposed to repl package since used by SparkImports + private[repl] lazy val memberHandlers = new { val intp: imain.type = imain } with SparkMemberHandlers import memberHandlers._ - /** Temporarily be quiet */ + /** + * Suppresses overwriting print results during the operation. + * + * @param body The block to execute + * @tparam T The return type of the block + * + * @return The result from executing the block + */ + @DeveloperApi def beQuietDuring[T](body: => T): T = { val saved = printResults printResults = false try body finally printResults = saved } + + /** + * Completely masks all output during the operation (minus JVM standard + * out and error). + * + * @param operation The block to execute + * @tparam T The return type of the block + * + * @return The result from executing the block + */ + @DeveloperApi def beSilentDuring[T](operation: => T): T = { val saved = totalSilence totalSilence = true @@ -277,10 +342,10 @@ import org.apache.spark.util.Utils finally totalSilence = saved } - def quietRun[T](code: String) = beQuietDuring(interpret(code)) + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { + private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { case t: ControlThrowable => throw t case t: Throwable => logDebug(label + ": " + unwrap(t)) @@ -298,14 +363,44 @@ import org.apache.spark.util.Utils finally bindExceptions = true } + /** + * Contains the code (in string form) representing a wrapper around all + * code 
executed by this instance. + * + * @return The wrapper code as a string + */ + @DeveloperApi def executionWrapper = _executionWrapper + + /** + * Sets the code to use as a wrapper around all code executed by this + * instance. + * + * @param code The wrapper code as a string + */ + @DeveloperApi def setExecutionWrapper(code: String) = _executionWrapper = code + + /** + * Clears the code used as a wrapper around all code executed by + * this instance. + */ + @DeveloperApi def clearExecutionWrapper() = _executionWrapper = "" /** interpreter settings */ - lazy val isettings = new SparkISettings(this) + private lazy val isettings = new SparkISettings(this) - /** Instantiate a compiler. Overridable. */ + /** + * Instantiates a new compiler used by SparkIMain. Overridable to provide + * own instance of a compiler. + * + * @param settings The settings to provide the compiler + * @param reporter The reporter to use for compiler output + * + * @return The compiler as a Global + */ + @DeveloperApi protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = { settings.outputDirs setSingleOutput virtualDirectory settings.exposeEmptyPackage.value = true @@ -320,13 +415,14 @@ import org.apache.spark.util.Utils * @note Currently only supports jars, not directories * @param urls The list of items to add to the compile and runtime classpaths */ + @DeveloperApi def addUrlsToClassPath(urls: URL*): Unit = { new Run // Needed to force initialization of "something" to correctly load Scala classes from jars urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling } - protected def updateCompilerClassPath(urls: URL*): Unit = { + private def updateCompilerClassPath(urls: URL*): Unit = { require(!global.forMSIL) // Only support JavaPlatform val platform = global.platform.asInstanceOf[JavaPlatform] @@ -342,7 +438,7 @@ import org.apache.spark.util.Utils 
global.invalidateClassPathEntries(urls.map(_.getPath): _*) } - protected def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { + private def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { // Collect our new jars/directories and add them to the existing set of classpaths val allClassPaths = ( platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++ @@ -365,7 +461,13 @@ import org.apache.spark.util.Utils new MergedClassPath(allClassPaths, platform.classPath.context) } - /** Parent classloader. Overridable. */ + /** + * Represents the parent classloader used by this instance. Can be + * overridden to provide alternative classloader. + * + * @return The classloader used as the parent loader of this instance + */ + @DeveloperApi protected def parentClassLoader: ClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) @@ -382,16 +484,18 @@ import org.apache.spark.util.Utils shadow the old ones, and old code objects refer to the old definitions. 
*/ - def resetClassLoader() = { + private def resetClassLoader() = { logDebug("Setting new classloader: was " + _classLoader) _classLoader = null ensureClassLoader() } - final def ensureClassLoader() { + private final def ensureClassLoader() { if (_classLoader == null) _classLoader = makeClassLoader() } - def classLoader: AbstractFileClassLoader = { + + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def classLoader: AbstractFileClassLoader = { ensureClassLoader() _classLoader } @@ -418,27 +522,58 @@ import org.apache.spark.util.Utils _runtimeClassLoader }) - def getInterpreterClassLoader() = classLoader + private def getInterpreterClassLoader() = classLoader // Set the current Java "context" class loader to this interpreter's class loader - def setContextClassLoader() = classLoader.setAsContext() + // NOTE: Exposed to repl package since used by SparkILoopInit + private[repl] def setContextClassLoader() = classLoader.setAsContext() - /** Given a simple repl-defined name, returns the real name of - * the class representing it, e.g. for "Bippy" it may return - * {{{ - * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy - * }}} + /** + * Returns the real name of a class based on its repl-defined name. + * + * ==Example== + * Given a simple repl-defined name, returns the real name of + * the class representing it, e.g. 
for "Bippy" it may return + * {{{ + * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy + * }}} + * + * @param simpleName The repl-defined name whose real name to retrieve + * + * @return Some real name if the simple name exists, else None */ + @DeveloperApi def generatedName(simpleName: String): Option[String] = { if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING) else optFlatName(simpleName) } - def flatName(id: String) = optFlatName(id) getOrElse id - def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def flatName(id: String) = optFlatName(id) getOrElse id + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + + /** + * Retrieves all simple names contained in the current instance. + * + * @return A list of sorted names + */ + @DeveloperApi def allDefinedNames = definedNameMap.keys.toList.sorted - def pathToType(id: String): String = pathToName(newTypeName(id)) - def pathToTerm(id: String): String = pathToName(newTermName(id)) + + private def pathToType(id: String): String = pathToName(newTypeName(id)) + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def pathToTerm(id: String): String = pathToName(newTermName(id)) + + /** + * Retrieves the full code path to access the specified simple name + * content. + * + * @param name The simple name of the target whose path to determine + * + * @return The full path used to access the specified target (name) + */ + @DeveloperApi def pathToName(name: Name): String = { if (definedNameMap contains name) definedNameMap(name) fullPath name @@ -457,13 +592,13 @@ import org.apache.spark.util.Utils } /** Stubs for work in progress. 
*/ - def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { + private def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) } } - def handleTermRedefinition(name: TermName, old: Request, req: Request) = { + private def handleTermRedefinition(name: TermName, old: Request, req: Request) = { for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { // Printing the types here has a tendency to cause assertion errors, like // assertion failed: fatal: has owner value x, but a class owner is required @@ -473,7 +608,7 @@ import org.apache.spark.util.Utils } } - def recordRequest(req: Request) { + private def recordRequest(req: Request) { if (req == null || referencedNameMap == null) return @@ -504,12 +639,12 @@ import org.apache.spark.util.Utils } } - def replwarn(msg: => String) { + private def replwarn(msg: => String) { if (!settings.nowarnings.value) printMessage(msg) } - def isParseable(line: String): Boolean = { + private def isParseable(line: String): Boolean = { beSilentDuring { try parse(line) match { case Some(xs) => xs.nonEmpty // parses as-is @@ -522,22 +657,32 @@ import org.apache.spark.util.Utils } } - def compileSourcesKeepingRun(sources: SourceFile*) = { + private def compileSourcesKeepingRun(sources: SourceFile*) = { val run = new Run() reporter.reset() run compileSources sources.toList (!reporter.hasErrors, run) } - /** Compile an nsc SourceFile. Returns true if there are - * no compilation errors, or false otherwise. + /** + * Compiles specified source files. + * + * @param sources The sequence of source files to compile + * + * @return True if successful, otherwise false */ + @DeveloperApi def compileSources(sources: SourceFile*): Boolean = compileSourcesKeepingRun(sources: _*)._1 - /** Compile a string. 
Returns true if there are no - * compilation errors, or false otherwise. + /** + * Compiles a string of code. + * + * @param code The string of code to compile + * + * @return True if successful, otherwise false */ + @DeveloperApi def compileString(code: String): Boolean = compileSources(new BatchSourceFile("